mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| b8323024c9 | |||
| d1768606c0 | |||
| ddd1c5e42f | |||
| a270e8b310 | |||
| f330ddce01 | |||
| b26b30ac3d | |||
| dae7c7870e | |||
| 4f56437030 | |||
| 42fd0cc22f | |||
| a4202028d9 | |||
| 4a0278420f | |||
| ade4a55cfe | |||
| 09872af36c | |||
| 92f562920d | |||
| 9d51e38641 | |||
| c966eb71a7 | |||
| c4368c9018 | |||
| f83767bb17 | |||
| 0e939bfe23 | |||
| 89da9b70db | |||
| a52deada8b | |||
| b7097baaec | |||
| 87200ff920 | |||
| 2d5f0787de | |||
| 5819bd8a59 | |||
| b3c2cc42cf | |||
| 167ef4512f | |||
| ba9cc5e972 | |||
| 6a94b58ad1 | |||
| d06643d8a2 | |||
| 92c185b90d | |||
| 9effa7be6d | |||
| 582bfda6f8 | |||
| 05ae4467ae | |||
| b66152c514 | |||
| 78fbc0abdb | |||
| ec5ed185cd | |||
| dbe3a3bb0d | |||
| 2b795265e7 | |||
| a57d05fe0a | |||
| ae9e8bc0bf |
@@ -10,7 +10,7 @@ permissions:
|
|||||||
contents: read
|
contents: read
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
lint:
|
lint-backend:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v6
|
- uses: actions/checkout@v6
|
||||||
|
|||||||
@@ -247,6 +247,9 @@ Access: http://localhost:2026
|
|||||||
|
|
||||||
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
The unified nginx endpoint is same-origin by default and does not emit browser CORS headers. If you run a split-origin or port-forwarded browser client, set `GATEWAY_CORS_ORIGINS` to comma-separated exact origins such as `http://localhost:3000`; the Gateway then applies the CORS allowlist and matching CSRF origin checks.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> The Gateway holds run state (RunManager and the stream bridge) in process, so production defaults to a single Gateway worker (`GATEWAY_WORKERS=1`). Raising the worker count without a shared cross-worker stream bridge — which is not yet available — breaks run cancellation, SSE reconnects, request de-duplication, and IM channels, because nginx uses no sticky sessions and each worker keeps its own run state. Scale a single worker up with more CPU/RAM (or move the database and sandbox onto dedicated tiers) instead of raising `GATEWAY_WORKERS`.
|
||||||
|
|
||||||
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
See [CONTRIBUTING.md](CONTRIBUTING.md) for detailed Docker development guide.
|
||||||
|
|
||||||
#### Option 2: Local Development
|
#### Option 2: Local Development
|
||||||
@@ -340,6 +343,8 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions
|
|||||||
|
|
||||||
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
|
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
|
||||||
|
|
||||||
|
DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, or WeCom from the sidebar / Settings > Channels. It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes.
|
||||||
|
|
||||||
| Channel | Transport | Difficulty |
|
| Channel | Transport | Difficulty |
|
||||||
|---------|-----------|------------|
|
|---------|-----------|------------|
|
||||||
| Telegram | Bot API (long-polling) | Easy |
|
| Telegram | Bot API (long-polling) | Easy |
|
||||||
|
|||||||
+29
-10
@@ -369,8 +369,7 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
|||||||
|
|
||||||
### IM Channels System (`app/channels/`)
|
### IM Channels System (`app/channels/`)
|
||||||
|
|
||||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||||
|
|
||||||
|
|
||||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
||||||
|
|
||||||
@@ -380,18 +379,21 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
||||||
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
||||||
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
||||||
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
- `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
||||||
|
- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs
|
||||||
|
- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store
|
||||||
|
|
||||||
**Message Flow**:
|
**Message Flow**:
|
||||||
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
||||||
2. `ChannelManager._dispatch_loop()` consumes from queue
|
2. `ChannelManager._dispatch_loop()` consumes from queue
|
||||||
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id`
|
||||||
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
4. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
||||||
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
5. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||||
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
6. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
||||||
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||||
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
8. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
||||||
9. Outbound → channel callbacks → platform reply
|
9. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||||
|
10. Outbound → channel callbacks → platform reply
|
||||||
|
|
||||||
**Configuration** (`config.yaml` -> `channels`):
|
**Configuration** (`config.yaml` -> `channels`):
|
||||||
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
||||||
@@ -399,6 +401,16 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||||
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
||||||
|
|
||||||
|
**User-owned channel connections** (`config.yaml` -> `channel_connections`):
|
||||||
|
- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials.
|
||||||
|
- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation.
|
||||||
|
- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom use `/connect <code>` over their existing outbound channel workers.
|
||||||
|
- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`.
|
||||||
|
- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers.
|
||||||
|
- Slack replies use the configured operator bot token from `channels.slack` unless a future provider-token flow stores per-connection credentials.
|
||||||
|
- Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
|
||||||
|
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
|
||||||
|
|
||||||
|
|
||||||
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
||||||
|
|
||||||
@@ -429,6 +441,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the
|
|||||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||||
|
|
||||||
|
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
|
||||||
|
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
|
||||||
|
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
|
||||||
|
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
|
||||||
|
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
|
||||||
|
|
||||||
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
||||||
|
|
||||||
**Configuration** (`config.yaml` → `memory`):
|
**Configuration** (`config.yaml` → `memory`):
|
||||||
@@ -438,6 +456,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
|||||||
- `model_name` - LLM for updates (null = default model)
|
- `model_name` - LLM for updates (null = default model)
|
||||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||||
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
||||||
|
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
|
||||||
|
|
||||||
### Reflection System (`packages/harness/deerflow/reflection/`)
|
### Reflection System (`packages/harness/deerflow/reflection/`)
|
||||||
|
|
||||||
|
|||||||
@@ -20,6 +20,17 @@ KNOWN_CHANNEL_COMMANDS: frozenset[str] = frozenset(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_connect_code(text: str) -> str | None:
|
||||||
|
"""Extract the one-time channel binding code from a connect command."""
|
||||||
|
parts = text.strip().split()
|
||||||
|
if len(parts) < 2:
|
||||||
|
return None
|
||||||
|
command = parts[0].lower()
|
||||||
|
if command in {"/connect", "connect"}:
|
||||||
|
return parts[1]
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def is_known_channel_command(text: str) -> bool:
|
def is_known_channel_command(text: str) -> bool:
|
||||||
"""Return whether text starts with a registered channel control command."""
|
"""Return whether text starts with a registered channel control command."""
|
||||||
if not text.startswith("/"):
|
if not text.startswith("/"):
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
"""Helpers for attaching persisted channel connection ownership to inbound messages."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.channels.message_bus import InboundMessage
|
||||||
|
|
||||||
|
|
||||||
|
async def attach_connection_identity(
|
||||||
|
inbound: InboundMessage,
|
||||||
|
*,
|
||||||
|
repo: Any,
|
||||||
|
provider: str,
|
||||||
|
workspace_id: str | None,
|
||||||
|
fallback_without_workspace: bool = False,
|
||||||
|
) -> InboundMessage:
|
||||||
|
"""Attach connection metadata to an inbound message when a persisted binding exists."""
|
||||||
|
if repo is None:
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
workspace_candidates: list[str | None] = []
|
||||||
|
if workspace_id:
|
||||||
|
workspace_candidates.append(workspace_id)
|
||||||
|
if fallback_without_workspace:
|
||||||
|
workspace_candidates.append(None)
|
||||||
|
if not workspace_candidates:
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
for candidate in workspace_candidates:
|
||||||
|
connection = await repo.find_connection_by_external_identity(
|
||||||
|
provider=provider,
|
||||||
|
external_account_id=inbound.user_id,
|
||||||
|
workspace_id=candidate,
|
||||||
|
)
|
||||||
|
if connection is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
inbound.connection_id = connection["id"]
|
||||||
|
inbound.owner_user_id = connection["owner_user_id"]
|
||||||
|
inbound.workspace_id = connection.get("workspace_id")
|
||||||
|
return inbound
|
||||||
|
|
||||||
|
return inbound
|
||||||
@@ -14,7 +14,8 @@ from typing import Any
|
|||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
|
|||||||
self._incoming_messages: dict[str, Any] = {}
|
self._incoming_messages: dict[str, Any] = {}
|
||||||
self._incoming_messages_lock = threading.Lock()
|
self._incoming_messages_lock = threading.Lock()
|
||||||
self._card_repliers: dict[str, Any] = {}
|
self._card_repliers: dict[str, Any] = {}
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
|
|||||||
text[:100],
|
text[:100],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
if self._main_loop and self._main_loop.is_running():
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._bind_connection_from_connect_code(
|
||||||
|
conversation_type=conversation_type,
|
||||||
|
sender_staff_id=sender_staff_id,
|
||||||
|
sender_nick=sender_nick,
|
||||||
|
conversation_id=conversation_id,
|
||||||
|
code=connect_code,
|
||||||
|
),
|
||||||
|
self._main_loop,
|
||||||
|
)
|
||||||
|
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||||
|
else:
|
||||||
|
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
|
||||||
|
return
|
||||||
|
|
||||||
if _is_dingtalk_command(text):
|
if _is_dingtalk_command(text):
|
||||||
msg_type = InboundMessageType.COMMAND
|
msg_type = InboundMessageType.COMMAND
|
||||||
else:
|
else:
|
||||||
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
# Running reply must finish before publish_inbound so AI card tracks are
|
# Running reply must finish before publish_inbound so AI card tracks are
|
||||||
# registered before the manager emits streaming outbounds.
|
# registered before the manager emits streaming outbounds.
|
||||||
await self._send_running_reply(chat_id, inbound)
|
await self._send_running_reply(chat_id, inbound)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
|
||||||
|
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
|
||||||
|
return conversation_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
|
||||||
|
conversation_id = str(inbound.metadata.get("conversation_id") or "")
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="dingtalk",
|
||||||
|
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||||
|
fallback_without_workspace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
conversation_type: str,
|
||||||
|
sender_staff_id: str,
|
||||||
|
sender_nick: str,
|
||||||
|
conversation_id: str,
|
||||||
|
code: str,
|
||||||
|
) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(
|
||||||
|
conversation_type,
|
||||||
|
sender_staff_id,
|
||||||
|
conversation_id,
|
||||||
|
"DingTalk connection code is invalid or expired.",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not sender_staff_id:
|
||||||
|
await self._send_connection_reply(
|
||||||
|
conversation_type,
|
||||||
|
sender_staff_id,
|
||||||
|
conversation_id,
|
||||||
|
"DingTalk connection could not be completed from this message.",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="dingtalk",
|
||||||
|
external_account_id=sender_staff_id,
|
||||||
|
external_account_name=sender_nick or None,
|
||||||
|
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||||
|
metadata={
|
||||||
|
"conversation_type": conversation_type,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(
|
||||||
|
conversation_type,
|
||||||
|
sender_staff_id,
|
||||||
|
conversation_id,
|
||||||
|
"DingTalk connected to DeerFlow.",
|
||||||
|
)
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _send_connection_reply(
|
||||||
|
self,
|
||||||
|
conversation_type: str,
|
||||||
|
sender_staff_id: str,
|
||||||
|
conversation_id: str,
|
||||||
|
text: str,
|
||||||
|
) -> None:
|
||||||
|
robot_code = self._client_id
|
||||||
|
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
||||||
|
if conversation_id:
|
||||||
|
await self._send_text_message_to_group(robot_code, conversation_id, text)
|
||||||
|
return
|
||||||
|
if sender_staff_id:
|
||||||
|
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||||
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
||||||
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
||||||
|
|||||||
@@ -10,8 +10,9 @@ from pathlib import Path
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -70,6 +71,7 @@ class DiscordChannel(Channel):
|
|||||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._discord_module = None
|
self._discord_module = None
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -287,6 +289,10 @@ class DiscordChannel(Channel):
|
|||||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||||
|
return
|
||||||
|
|
||||||
# --- Determine thread/channel routing and typing target ---
|
# --- Determine thread/channel routing and typing target ---
|
||||||
thread_id = None
|
thread_id = None
|
||||||
chat_id = None
|
chat_id = None
|
||||||
@@ -315,6 +321,7 @@ class DiscordChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = thread_id
|
inbound.topic_id = thread_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||||
self._publish(inbound)
|
self._publish(inbound)
|
||||||
# Start typing indicator in the thread
|
# Start typing indicator in the thread
|
||||||
if typing_target:
|
if typing_target:
|
||||||
@@ -422,6 +429,7 @@ class DiscordChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = thread_id
|
inbound.topic_id = thread_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||||
|
|
||||||
# Start typing indicator in the correct target (thread or channel)
|
# Start typing indicator in the correct target (thread or channel)
|
||||||
if typing_target:
|
if typing_target:
|
||||||
@@ -436,6 +444,60 @@ class DiscordChannel(Channel):
|
|||||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="discord",
|
||||||
|
workspace_id=guild_id,
|
||||||
|
fallback_without_workspace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
guild = getattr(message, "guild", None)
|
||||||
|
channel = getattr(message, "channel", None)
|
||||||
|
author = getattr(message, "author", None)
|
||||||
|
user_id = str(getattr(author, "id", "") or "")
|
||||||
|
if not user_id:
|
||||||
|
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
guild_id = str(getattr(guild, "id", "") or "") or None
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="discord",
|
||||||
|
external_account_id=user_id,
|
||||||
|
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
|
||||||
|
workspace_id=guild_id,
|
||||||
|
workspace_name=getattr(guild, "name", None) if guild is not None else None,
|
||||||
|
metadata={
|
||||||
|
"guild_id": guild_id,
|
||||||
|
"channel_id": str(getattr(channel, "id", "") or ""),
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _send_connection_reply(message, text: str) -> None:
|
||||||
|
channel = getattr(message, "channel", None)
|
||||||
|
send = getattr(channel, "send", None)
|
||||||
|
if send is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await send(text)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Discord] failed to send connection reply")
|
||||||
|
|
||||||
def _run_client(self) -> None:
|
def _run_client(self) -> None:
|
||||||
self._discord_loop = asyncio.new_event_loop()
|
self._discord_loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(self._discord_loop)
|
asyncio.set_event_loop(self._discord_loop)
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ import time
|
|||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import (
|
from app.channels.message_bus import (
|
||||||
PENDING_CLARIFICATION_METADATA_KEY,
|
PENDING_CLARIFICATION_METADATA_KEY,
|
||||||
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
||||||
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
|
|||||||
self._CreateImageRequestBody = None
|
self._CreateImageRequestBody = None
|
||||||
self._GetMessageResourceRequest = None
|
self._GetMessageResourceRequest = None
|
||||||
self._thread_lock = threading.Lock()
|
self._thread_lock = threading.Lock()
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _non_empty_str(value: Any) -> str | None:
|
def _non_empty_str(value: Any) -> str | None:
|
||||||
@@ -86,6 +88,23 @@ class FeishuChannel(Channel):
|
|||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_running(self) -> bool:
|
||||||
|
if not self._running:
|
||||||
|
return False
|
||||||
|
return self._thread is not None and self._thread.is_alive()
|
||||||
|
|
||||||
|
def _build_event_handler(self, lark):
|
||||||
|
return (
|
||||||
|
lark.EventDispatcherHandler.builder("", "")
|
||||||
|
.register_p2_im_message_receive_v1(self._on_message)
|
||||||
|
.register_p2_im_message_message_read_v1(self._on_ignored_message_event)
|
||||||
|
.register_p2_im_message_reaction_created_v1(self._on_ignored_message_event)
|
||||||
|
.register_p2_im_message_reaction_deleted_v1(self._on_ignored_message_event)
|
||||||
|
.register_p2_im_message_recalled_v1(self._on_ignored_message_event)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
@@ -179,7 +198,7 @@ class FeishuChannel(Channel):
|
|||||||
# thread's uvloop.
|
# thread's uvloop.
|
||||||
_ws_client_mod.loop = loop
|
_ws_client_mod.loop = loop
|
||||||
|
|
||||||
event_handler = lark.EventDispatcherHandler.builder("", "").register_p2_im_message_receive_v1(self._on_message).build()
|
event_handler = self._build_event_handler(lark)
|
||||||
ws_client = lark.ws.Client(
|
ws_client = lark.ws.Client(
|
||||||
app_id=app_id,
|
app_id=app_id,
|
||||||
app_secret=app_secret,
|
app_secret=app_secret,
|
||||||
@@ -191,6 +210,10 @@ class FeishuChannel(Channel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
if self._running:
|
if self._running:
|
||||||
logger.exception("Feishu WebSocket error")
|
logger.exception("Feishu WebSocket error")
|
||||||
|
self._running = False
|
||||||
|
|
||||||
|
def _on_ignored_message_event(self, event) -> None:
|
||||||
|
logger.debug("[Feishu] ignoring non-content message event: %s", type(event).__name__)
|
||||||
|
|
||||||
async def stop(self) -> None:
|
async def stop(self) -> None:
|
||||||
self._running = False
|
self._running = False
|
||||||
@@ -726,11 +749,47 @@ class FeishuChannel(Channel):
|
|||||||
|
|
||||||
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
||||||
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
||||||
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
||||||
self._ensure_running_card_started(msg_id)
|
self._ensure_running_card_started(msg_id)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="feishu",
|
||||||
|
workspace_id=inbound.chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not user_id or not chat_id:
|
||||||
|
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="feishu",
|
||||||
|
external_account_id=user_id,
|
||||||
|
workspace_id=chat_id,
|
||||||
|
metadata={
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"message_id": message_id,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
def _on_message(self, event) -> None:
|
def _on_message(self, event) -> None:
|
||||||
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
||||||
try:
|
try:
|
||||||
@@ -819,6 +878,23 @@ class FeishuChannel(Channel):
|
|||||||
logger.info("[Feishu] empty text, ignoring message")
|
logger.info("[Feishu] empty text, ignoring message")
|
||||||
return
|
return
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
if self._main_loop and self._main_loop.is_running():
|
||||||
|
fut = asyncio.run_coroutine_threadsafe(
|
||||||
|
self._bind_connection_from_connect_code(
|
||||||
|
message_id=msg_id,
|
||||||
|
chat_id=chat_id,
|
||||||
|
user_id=sender_id,
|
||||||
|
code=connect_code,
|
||||||
|
),
|
||||||
|
self._main_loop,
|
||||||
|
)
|
||||||
|
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||||
|
else:
|
||||||
|
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
|
||||||
|
return
|
||||||
|
|
||||||
# Only treat known slash commands as commands; absolute paths and
|
# Only treat known slash commands as commands; absolute paths and
|
||||||
# other slash-prefixed text should be handled as normal chat.
|
# other slash-prefixed text should be handled as normal chat.
|
||||||
if _is_feishu_command(text):
|
if _is_feishu_command(text):
|
||||||
|
|||||||
+132
-30
@@ -274,6 +274,22 @@ def _response_metadata(base_metadata: dict[str, Any], *, pending_clarification:
|
|||||||
return metadata
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _thread_channel_metadata(msg: InboundMessage) -> dict[str, Any]:
|
||||||
|
channel_source: dict[str, Any] = {
|
||||||
|
"type": "im_channel",
|
||||||
|
"provider": msg.channel_name,
|
||||||
|
"chat_id": msg.chat_id,
|
||||||
|
}
|
||||||
|
if msg.topic_id:
|
||||||
|
channel_source["topic_id"] = msg.topic_id
|
||||||
|
if msg.thread_ts:
|
||||||
|
channel_source["thread_ts"] = msg.thread_ts
|
||||||
|
if msg.connection_id:
|
||||||
|
channel_source["connection_id"] = msg.connection_id
|
||||||
|
|
||||||
|
return {"channel_source": channel_source}
|
||||||
|
|
||||||
|
|
||||||
def _extract_text_content(content: Any) -> str:
|
def _extract_text_content(content: Any) -> str:
|
||||||
"""Extract text from a streaming payload content field."""
|
"""Extract text from a streaming payload content field."""
|
||||||
if isinstance(content, str):
|
if isinstance(content, str):
|
||||||
@@ -440,6 +456,33 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_disabled_owner_user_id() -> str | None:
|
||||||
|
try:
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to inspect auth-disabled mode for channel owner fallback", exc_info=True)
|
||||||
|
return None
|
||||||
|
return AUTH_DISABLED_USER_ID if is_auth_disabled() else None
|
||||||
|
|
||||||
|
|
||||||
|
def _effective_owner_user_id(msg: InboundMessage) -> str | None:
|
||||||
|
return _auth_disabled_owner_user_id() or msg.owner_user_id
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_effective_owner(msg: InboundMessage) -> InboundMessage:
|
||||||
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if owner_user_id:
|
||||||
|
msg.owner_user_id = owner_user_id
|
||||||
|
return msg
|
||||||
|
|
||||||
|
|
||||||
|
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||||
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if not owner_user_id:
|
||||||
|
return None
|
||||||
|
return create_internal_auth_headers(owner_user_id=owner_user_id)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_slash_skill_command(
|
def _resolve_slash_skill_command(
|
||||||
text: str,
|
text: str,
|
||||||
available_skills: set[str] | None = None,
|
available_skills: set[str] | None = None,
|
||||||
@@ -670,6 +713,7 @@ class ChannelManager:
|
|||||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||||
default_session: dict[str, Any] | None = None,
|
default_session: dict[str, Any] | None = None,
|
||||||
channel_sessions: dict[str, Any] | None = None,
|
channel_sessions: dict[str, Any] | None = None,
|
||||||
|
connection_repo: Any | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.bus = bus
|
self.bus = bus
|
||||||
self.store = store
|
self.store = store
|
||||||
@@ -679,6 +723,7 @@ class ChannelManager:
|
|||||||
self._assistant_id = assistant_id
|
self._assistant_id = assistant_id
|
||||||
self._default_session = _as_dict(default_session)
|
self._default_session = _as_dict(default_session)
|
||||||
self._channel_sessions = dict(channel_sessions or {})
|
self._channel_sessions = dict(channel_sessions or {})
|
||||||
|
self._connection_repo = connection_repo
|
||||||
self._client = None # lazy init — langgraph_sdk async client
|
self._client = None # lazy init — langgraph_sdk async client
|
||||||
self._skill_storage: SkillStorage | None = None
|
self._skill_storage: SkillStorage | None = None
|
||||||
self._csrf_token = generate_csrf_token()
|
self._csrf_token = generate_csrf_token()
|
||||||
@@ -728,12 +773,17 @@ class ChannelManager:
|
|||||||
configurable["checkpoint_ns"] = ""
|
configurable["checkpoint_ns"] = ""
|
||||||
configurable["thread_id"] = thread_id
|
configurable["thread_id"] = thread_id
|
||||||
|
|
||||||
# ``user_id`` drives user-scoped filesystem buckets that only accept
|
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
|
||||||
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
|
# For browser-connected IM channels, prefer the DeerFlow account that
|
||||||
# under ``channel_user_id`` for platform-facing lookups.
|
# owns the connection. Preserve the raw platform user under
|
||||||
|
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||||
if msg.user_id:
|
owner_user_id = _effective_owner_user_id(msg)
|
||||||
|
if owner_user_id:
|
||||||
|
run_context_identity["user_id"] = make_safe_user_id(owner_user_id)
|
||||||
|
elif msg.user_id:
|
||||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||||
|
if msg.user_id:
|
||||||
run_context_identity["channel_user_id"] = msg.user_id
|
run_context_identity["channel_user_id"] = msg.user_id
|
||||||
|
|
||||||
run_context = _merge_dicts(
|
run_context = _merge_dicts(
|
||||||
@@ -845,6 +895,7 @@ class ChannelManager:
|
|||||||
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
logger.error("[Manager] unhandled error in message task: %s", exc, exc_info=exc)
|
||||||
|
|
||||||
async def _handle_message(self, msg: InboundMessage) -> None:
|
async def _handle_message(self, msg: InboundMessage) -> None:
|
||||||
|
msg = _apply_effective_owner(msg)
|
||||||
async with self._semaphore:
|
async with self._semaphore:
|
||||||
try:
|
try:
|
||||||
if msg.msg_type == InboundMessageType.COMMAND:
|
if msg.msg_type == InboundMessageType.COMMAND:
|
||||||
@@ -877,10 +928,27 @@ class ChannelManager:
|
|||||||
|
|
||||||
# -- chat handling -----------------------------------------------------
|
# -- chat handling -----------------------------------------------------
|
||||||
|
|
||||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
|
||||||
"""Create a new thread through Gateway and store the mapping."""
|
if msg.connection_id and self._connection_repo is not None:
|
||||||
thread = await client.threads.create()
|
return await self._connection_repo.get_thread_id(
|
||||||
thread_id = thread["thread_id"]
|
msg.connection_id,
|
||||||
|
msg.chat_id,
|
||||||
|
msg.topic_id,
|
||||||
|
)
|
||||||
|
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||||
|
|
||||||
|
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
|
||||||
|
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
|
||||||
|
await self._connection_repo.set_thread_id(
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
|
provider=msg.channel_name,
|
||||||
|
external_conversation_id=msg.chat_id,
|
||||||
|
external_topic_id=msg.topic_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
self.store.set_thread_id(
|
self.store.set_thread_id(
|
||||||
msg.channel_name,
|
msg.channel_name,
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
@@ -888,18 +956,40 @@ class ChannelManager:
|
|||||||
topic_id=msg.topic_id,
|
topic_id=msg.topic_id,
|
||||||
user_id=msg.user_id,
|
user_id=msg.user_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||||
|
"""Create a new thread through Gateway and store the mapping."""
|
||||||
|
metadata = _thread_channel_metadata(msg)
|
||||||
|
owner_headers = _owner_headers(msg)
|
||||||
|
if owner_headers:
|
||||||
|
thread = await client.threads.create(metadata=metadata, headers=owner_headers)
|
||||||
|
else:
|
||||||
|
thread = await client.threads.create(metadata=metadata)
|
||||||
|
thread_id = thread["thread_id"]
|
||||||
|
await self._store_thread_id(msg, thread_id)
|
||||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||||
return thread_id
|
return thread_id
|
||||||
|
|
||||||
|
async def _update_thread_channel_metadata(self, client, msg: InboundMessage, thread_id: str) -> None:
|
||||||
|
"""Best-effort source metadata backfill for existing IM-created threads."""
|
||||||
|
update_kwargs: dict[str, Any] = {"metadata": _thread_channel_metadata(msg)}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
update_kwargs["headers"] = owner_headers
|
||||||
|
try:
|
||||||
|
await client.threads.update(thread_id, **update_kwargs)
|
||||||
|
except Exception:
|
||||||
|
logger.debug("[Manager] failed to update channel metadata for thread_id=%s", thread_id, exc_info=True)
|
||||||
|
|
||||||
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
async def _handle_chat(self, msg: InboundMessage, extra_context: dict[str, Any] | None = None) -> None:
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
|
|
||||||
# Look up existing DeerFlow thread.
|
# Look up existing DeerFlow thread.
|
||||||
# topic_id may be None (e.g. Telegram private chats) — the store
|
# topic_id may be None (e.g. Telegram private chats) — the store
|
||||||
# handles this by using the "channel:chat_id" key without a topic suffix.
|
# handles this by using the "channel:chat_id" key without a topic suffix.
|
||||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
thread_id = await self._lookup_thread_id(msg)
|
||||||
if thread_id:
|
if thread_id:
|
||||||
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
||||||
|
await self._update_thread_channel_metadata(client, msg, thread_id)
|
||||||
|
|
||||||
# No existing thread found — create a new one
|
# No existing thread found — create a new one
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
@@ -940,14 +1030,19 @@ class ChannelManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
|
run_kwargs: dict[str, Any] = {
|
||||||
|
"input": {"messages": [human_message]},
|
||||||
|
"config": run_config,
|
||||||
|
"context": run_context,
|
||||||
|
"multitask_strategy": "reject",
|
||||||
|
}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
run_kwargs["headers"] = owner_headers
|
||||||
try:
|
try:
|
||||||
result = await client.runs.wait(
|
result = await client.runs.wait(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
input={"messages": [human_message]},
|
**run_kwargs,
|
||||||
config=run_config,
|
|
||||||
context=run_context,
|
|
||||||
multitask_strategy="reject",
|
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if _is_thread_busy_error(exc):
|
if _is_thread_busy_error(exc):
|
||||||
@@ -984,6 +1079,8 @@ class ChannelManager:
|
|||||||
artifacts=artifacts,
|
artifacts=artifacts,
|
||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||||
)
|
)
|
||||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||||
@@ -1008,16 +1105,21 @@ class ChannelManager:
|
|||||||
last_published_text = ""
|
last_published_text = ""
|
||||||
last_publish_at = 0.0
|
last_publish_at = 0.0
|
||||||
stream_error: BaseException | None = None
|
stream_error: BaseException | None = None
|
||||||
|
stream_kwargs: dict[str, Any] = {
|
||||||
|
"input": {"messages": [human_message]},
|
||||||
|
"config": run_config,
|
||||||
|
"context": run_context,
|
||||||
|
"stream_mode": ["messages-tuple", "values"],
|
||||||
|
"multitask_strategy": "reject",
|
||||||
|
}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
stream_kwargs["headers"] = owner_headers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in client.runs.stream(
|
async for chunk in client.runs.stream(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
input={"messages": [human_message]},
|
**stream_kwargs,
|
||||||
config=run_config,
|
|
||||||
context=run_context,
|
|
||||||
stream_mode=["messages-tuple", "values"],
|
|
||||||
multitask_strategy="reject",
|
|
||||||
):
|
):
|
||||||
event = getattr(chunk, "event", "")
|
event = getattr(chunk, "event", "")
|
||||||
data = getattr(chunk, "data", None)
|
data = getattr(chunk, "data", None)
|
||||||
@@ -1047,6 +1149,8 @@ class ChannelManager:
|
|||||||
text=latest_text,
|
text=latest_text,
|
||||||
is_final=False,
|
is_final=False,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
metadata=_response_metadata(msg.metadata),
|
metadata=_response_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1093,6 +1197,8 @@ class ChannelManager:
|
|||||||
attachments=attachments,
|
attachments=attachments,
|
||||||
is_final=True,
|
is_final=True,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -1124,18 +1230,10 @@ class ChannelManager:
|
|||||||
if reply is None and command == "new":
|
if reply is None and command == "new":
|
||||||
# Create a new thread through Gateway
|
# Create a new thread through Gateway
|
||||||
client = self._get_client()
|
client = self._get_client()
|
||||||
thread = await client.threads.create()
|
await self._create_thread(client, msg)
|
||||||
new_thread_id = thread["thread_id"]
|
|
||||||
self.store.set_thread_id(
|
|
||||||
msg.channel_name,
|
|
||||||
msg.chat_id,
|
|
||||||
new_thread_id,
|
|
||||||
topic_id=msg.topic_id,
|
|
||||||
user_id=msg.user_id,
|
|
||||||
)
|
|
||||||
reply = "New conversation started."
|
reply = "New conversation started."
|
||||||
elif reply is None and command == "status":
|
elif reply is None and command == "status":
|
||||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
thread_id = await self._lookup_thread_id(msg)
|
||||||
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
||||||
elif reply is None and command == "models":
|
elif reply is None and command == "models":
|
||||||
reply = await self._fetch_gateway("/api/models", "models")
|
reply = await self._fetch_gateway("/api/models", "models")
|
||||||
@@ -1174,9 +1272,11 @@ class ChannelManager:
|
|||||||
outbound = OutboundMessage(
|
outbound = OutboundMessage(
|
||||||
channel_name=msg.channel_name,
|
channel_name=msg.channel_name,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
thread_id=await self._lookup_thread_id(msg) or "",
|
||||||
text=reply,
|
text=reply,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
metadata=_slim_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
@@ -1212,9 +1312,11 @@ class ChannelManager:
|
|||||||
outbound = OutboundMessage(
|
outbound = OutboundMessage(
|
||||||
channel_name=msg.channel_name,
|
channel_name=msg.channel_name,
|
||||||
chat_id=msg.chat_id,
|
chat_id=msg.chat_id,
|
||||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
thread_id=await self._lookup_thread_id(msg) or "",
|
||||||
text=error_text,
|
text=error_text,
|
||||||
thread_ts=msg.thread_ts,
|
thread_ts=msg.thread_ts,
|
||||||
|
connection_id=msg.connection_id,
|
||||||
|
owner_user_id=msg.owner_user_id,
|
||||||
metadata=_slim_metadata(msg.metadata),
|
metadata=_slim_metadata(msg.metadata),
|
||||||
)
|
)
|
||||||
await self.bus.publish_outbound(outbound)
|
await self.bus.publish_outbound(outbound)
|
||||||
|
|||||||
@@ -44,6 +44,12 @@ class InboundMessage:
|
|||||||
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
||||||
reuse the same DeerFlow thread. When ``None``, each message
|
reuse the same DeerFlow thread. When ``None``, each message
|
||||||
creates a new thread (one-shot Q&A).
|
creates a new thread (one-shot Q&A).
|
||||||
|
connection_id: Optional DeerFlow channel connection id. When present,
|
||||||
|
conversation mapping is scoped by the connection instead of the
|
||||||
|
legacy global ``channel_name:chat_id[:topic_id]`` key.
|
||||||
|
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||||
|
Platform user ids stay in ``user_id``.
|
||||||
|
workspace_id: Optional external workspace/guild/team id.
|
||||||
files: Optional list of file attachments (platform-specific dicts).
|
files: Optional list of file attachments (platform-specific dicts).
|
||||||
metadata: Arbitrary extra data from the channel.
|
metadata: Arbitrary extra data from the channel.
|
||||||
created_at: Unix timestamp when the message was created.
|
created_at: Unix timestamp when the message was created.
|
||||||
@@ -56,6 +62,9 @@ class InboundMessage:
|
|||||||
msg_type: InboundMessageType = InboundMessageType.CHAT
|
msg_type: InboundMessageType = InboundMessageType.CHAT
|
||||||
thread_ts: str | None = None
|
thread_ts: str | None = None
|
||||||
topic_id: str | None = None
|
topic_id: str | None = None
|
||||||
|
connection_id: str | None = None
|
||||||
|
owner_user_id: str | None = None
|
||||||
|
workspace_id: str | None = None
|
||||||
files: list[dict[str, Any]] = field(default_factory=list)
|
files: list[dict[str, Any]] = field(default_factory=list)
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
created_at: float = field(default_factory=time.time)
|
created_at: float = field(default_factory=time.time)
|
||||||
@@ -95,6 +104,9 @@ class OutboundMessage:
|
|||||||
is_final: Whether this is the final message in the response stream.
|
is_final: Whether this is the final message in the response stream.
|
||||||
thread_ts: Optional platform thread identifier for threaded replies.
|
thread_ts: Optional platform thread identifier for threaded replies.
|
||||||
metadata: Arbitrary extra data.
|
metadata: Arbitrary extra data.
|
||||||
|
connection_id: Optional DeerFlow channel connection id used for
|
||||||
|
connection-specific outbound credentials.
|
||||||
|
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||||
created_at: Unix timestamp.
|
created_at: Unix timestamp.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -106,6 +118,8 @@ class OutboundMessage:
|
|||||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||||
is_final: bool = True
|
is_final: bool = True
|
||||||
thread_ts: str | None = None
|
thread_ts: str | None = None
|
||||||
|
connection_id: str | None = None
|
||||||
|
owner_user_id: str | None = None
|
||||||
metadata: dict[str, Any] = field(default_factory=dict)
|
metadata: dict[str, Any] = field(default_factory=dict)
|
||||||
created_at: float = field(default_factory=time.time)
|
created_at: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,137 @@
|
|||||||
|
"""Local persistence for runtime IM channel configuration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import tempfile
|
||||||
|
import threading
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelRuntimeConfigStore:
|
||||||
|
"""JSON-backed store for channel credentials entered from the UI.
|
||||||
|
|
||||||
|
This intentionally mirrors ``ChannelStore``: local/private deployments get
|
||||||
|
durable runtime configuration without needing a public callback URL or a
|
||||||
|
config.yaml edit.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, path: str | Path | None = None) -> None:
|
||||||
|
if path is None:
|
||||||
|
from deerflow.config.paths import get_paths
|
||||||
|
|
||||||
|
path = Path(get_paths().base_dir) / "channels" / "runtime-config.json"
|
||||||
|
self._path = Path(path)
|
||||||
|
self._path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._data: dict[str, dict[str, Any]] = self._load()
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def _load(self) -> dict[str, dict[str, Any]]:
|
||||||
|
if self._path.exists():
|
||||||
|
try:
|
||||||
|
raw = json.loads(self._path.read_text(encoding="utf-8"))
|
||||||
|
except (json.JSONDecodeError, OSError):
|
||||||
|
logger.warning("Corrupt channel runtime config store at %s, starting fresh", self._path)
|
||||||
|
return {}
|
||||||
|
if isinstance(raw, dict):
|
||||||
|
return {str(name): dict(value) for name, value in raw.items() if isinstance(value, dict)}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
def _save(self) -> None:
|
||||||
|
fd = tempfile.NamedTemporaryFile(
|
||||||
|
mode="w",
|
||||||
|
dir=self._path.parent,
|
||||||
|
suffix=".tmp",
|
||||||
|
delete=False,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
json.dump(self._data, fd, indent=2, ensure_ascii=False)
|
||||||
|
fd.close()
|
||||||
|
Path(fd.name).replace(self._path)
|
||||||
|
try:
|
||||||
|
self._path.chmod(0o600)
|
||||||
|
except OSError:
|
||||||
|
logger.debug("Unable to chmod channel runtime config store at %s", self._path, exc_info=True)
|
||||||
|
except BaseException:
|
||||||
|
fd.close()
|
||||||
|
Path(fd.name).unlink(missing_ok=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def load_all(self) -> dict[str, dict[str, Any]]:
|
||||||
|
with self._lock:
|
||||||
|
return {name: dict(config) for name, config in self._data.items()}
|
||||||
|
|
||||||
|
def get_provider_config(self, provider: str) -> dict[str, Any] | None:
|
||||||
|
with self._lock:
|
||||||
|
config = self._data.get(provider)
|
||||||
|
return dict(config) if isinstance(config, dict) else None
|
||||||
|
|
||||||
|
def set_provider_config(self, provider: str, config: dict[str, Any]) -> None:
|
||||||
|
with self._lock:
|
||||||
|
self._data[provider] = dict(config)
|
||||||
|
self._save()
|
||||||
|
|
||||||
|
def remove_provider_config(self, provider: str) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
if provider not in self._data:
|
||||||
|
return False
|
||||||
|
del self._data[provider]
|
||||||
|
self._save()
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_enabled(channel_connections_config: Any, provider: str) -> bool:
|
||||||
|
provider_config = getattr(channel_connections_config, provider, None)
|
||||||
|
return bool(getattr(provider_config, "enabled", False))
|
||||||
|
|
||||||
|
|
||||||
|
def merge_runtime_channel_configs(
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
channel_connections_config: Any,
|
||||||
|
*,
|
||||||
|
store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Merge persisted runtime provider config into ``channels_config`` in-place."""
|
||||||
|
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||||
|
return
|
||||||
|
|
||||||
|
runtime_store = store or ChannelRuntimeConfigStore()
|
||||||
|
for provider, runtime_config in runtime_store.load_all().items():
|
||||||
|
if not _provider_enabled(channel_connections_config, provider):
|
||||||
|
continue
|
||||||
|
existing = channels_config.get(provider)
|
||||||
|
merged = dict(runtime_config)
|
||||||
|
if isinstance(existing, dict):
|
||||||
|
merged.update(existing)
|
||||||
|
channels_config[provider] = merged
|
||||||
|
|
||||||
|
|
||||||
|
def apply_runtime_connection_config(
|
||||||
|
channel_connections_config: Any,
|
||||||
|
*,
|
||||||
|
store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Apply persisted connection metadata that lives outside ``channels``.
|
||||||
|
|
||||||
|
Telegram uses a bot username for deep links; UI-entered values are stored
|
||||||
|
with the runtime channel config so local restarts keep the provider
|
||||||
|
configured.
|
||||||
|
"""
|
||||||
|
if channel_connections_config is None or not getattr(channel_connections_config, "enabled", False):
|
||||||
|
return channel_connections_config
|
||||||
|
|
||||||
|
runtime_store = store or ChannelRuntimeConfigStore()
|
||||||
|
telegram_runtime_config = runtime_store.get_provider_config("telegram")
|
||||||
|
bot_username = ""
|
||||||
|
if isinstance(telegram_runtime_config, dict):
|
||||||
|
bot_username = str(telegram_runtime_config.get("bot_username") or "").strip()
|
||||||
|
if not bot_username or not _provider_enabled(channel_connections_config, "telegram"):
|
||||||
|
return channel_connections_config
|
||||||
|
|
||||||
|
config = channel_connections_config.model_copy(deep=True)
|
||||||
|
config.telegram.bot_username = bot_username
|
||||||
|
return config
|
||||||
@@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
from app.channels.manager import DEFAULT_GATEWAY_URL, DEFAULT_LANGGRAPH_URL, ChannelManager
|
||||||
from app.channels.message_bus import MessageBus
|
from app.channels.message_bus import MessageBus
|
||||||
|
from app.channels.runtime_config_store import merge_runtime_channel_configs
|
||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -52,6 +53,30 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
|||||||
return default
|
return default
|
||||||
|
|
||||||
|
|
||||||
|
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||||
|
connection_config = getattr(app_config, "channel_connections", None)
|
||||||
|
merge_runtime_channel_configs(channels_config, connection_config)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connection_repo(app_config: AppConfig):
|
||||||
|
connection_config = getattr(app_config, "channel_connections", None)
|
||||||
|
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||||
|
from deerflow.persistence.engine import get_session_factory
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to import channel connection repository")
|
||||||
|
return None
|
||||||
|
|
||||||
|
session_factory = get_session_factory()
|
||||||
|
if session_factory is None:
|
||||||
|
logger.warning("Channel connections are enabled but database persistence is not available")
|
||||||
|
return None
|
||||||
|
return ChannelConnectionRepository(session_factory)
|
||||||
|
|
||||||
|
|
||||||
class ChannelService:
|
class ChannelService:
|
||||||
"""Manages the lifecycle of all configured IM channels.
|
"""Manages the lifecycle of all configured IM channels.
|
||||||
|
|
||||||
@@ -59,9 +84,10 @@ class ChannelService:
|
|||||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
||||||
self.bus = MessageBus()
|
self.bus = MessageBus()
|
||||||
self.store = ChannelStore()
|
self.store = ChannelStore()
|
||||||
|
self._connection_repo = connection_repo
|
||||||
config = dict(channels_config or {})
|
config = dict(channels_config or {})
|
||||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||||
@@ -74,6 +100,7 @@ class ChannelService:
|
|||||||
gateway_url=gateway_url,
|
gateway_url=gateway_url,
|
||||||
default_session=default_session if isinstance(default_session, dict) else None,
|
default_session=default_session if isinstance(default_session, dict) else None,
|
||||||
channel_sessions=channel_sessions,
|
channel_sessions=channel_sessions,
|
||||||
|
connection_repo=connection_repo,
|
||||||
)
|
)
|
||||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||||
self._config = config
|
self._config = config
|
||||||
@@ -90,8 +117,9 @@ class ChannelService:
|
|||||||
# extra fields are allowed by AppConfig (extra="allow")
|
# extra fields are allowed by AppConfig (extra="allow")
|
||||||
extra = app_config.model_extra or {}
|
extra = app_config.model_extra or {}
|
||||||
if "channels" in extra:
|
if "channels" in extra:
|
||||||
channels_config = extra["channels"]
|
channels_config = dict(extra["channels"] or {})
|
||||||
return cls(channels_config=channels_config)
|
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||||
|
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
"""Start the manager and all enabled channels."""
|
"""Start the manager and all enabled channels."""
|
||||||
@@ -151,6 +179,27 @@ class ChannelService:
|
|||||||
|
|
||||||
return await self._start_channel(name, config)
|
return await self._start_channel(name, config)
|
||||||
|
|
||||||
|
async def configure_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||||
|
"""Apply runtime config for a channel and restart it if the service is running."""
|
||||||
|
self._config[name] = dict(config)
|
||||||
|
if not self._running:
|
||||||
|
return True
|
||||||
|
return await self.restart_channel(name)
|
||||||
|
|
||||||
|
async def remove_channel(self, name: str) -> bool:
|
||||||
|
"""Remove runtime config for a channel and stop it if currently running."""
|
||||||
|
self._config.pop(name, None)
|
||||||
|
channel = self._channels.pop(name, None)
|
||||||
|
if channel is None:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
await channel.stop()
|
||||||
|
logger.info("Channel %s stopped and removed", name)
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Error stopping channel %s for removal", name)
|
||||||
|
return False
|
||||||
|
|
||||||
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
async def _start_channel(self, name: str, config: dict[str, Any]) -> bool:
|
||||||
"""Instantiate and start a single channel."""
|
"""Instantiate and start a single channel."""
|
||||||
import_path = _CHANNEL_REGISTRY.get(name)
|
import_path = _CHANNEL_REGISTRY.get(name)
|
||||||
@@ -169,6 +218,8 @@ class ChannelService:
|
|||||||
try:
|
try:
|
||||||
config = dict(config)
|
config = dict(config)
|
||||||
config["channel_store"] = self.store
|
config["channel_store"] = self.store
|
||||||
|
if self._connection_repo is not None:
|
||||||
|
config["connection_repo"] = self._connection_repo
|
||||||
channel = channel_cls(bus=self.bus, config=config)
|
channel = channel_cls(bus=self.bus, config=config)
|
||||||
self._channels[name] = channel
|
self._channels[name] = channel
|
||||||
await channel.start()
|
await channel.start()
|
||||||
|
|||||||
+140
-26
@@ -9,7 +9,8 @@ from typing import Any
|
|||||||
from markdown_to_mrkdwn import SlackMarkdownConverter
|
from markdown_to_mrkdwn import SlackMarkdownConverter
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -64,6 +65,8 @@ class SlackChannel(Channel):
|
|||||||
self._web_client = None
|
self._web_client = None
|
||||||
self._loop: asyncio.AbstractEventLoop | None = None
|
self._loop: asyncio.AbstractEventLoop | None = None
|
||||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
self._web_client_factory = config.get("web_client_factory")
|
||||||
configured_bot_user_id = config.get("bot_user_id")
|
configured_bot_user_id = config.get("bot_user_id")
|
||||||
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
|
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
|
||||||
|
|
||||||
@@ -80,26 +83,28 @@ class SlackChannel(Channel):
|
|||||||
return
|
return
|
||||||
|
|
||||||
self._SocketModeResponse = SocketModeResponse
|
self._SocketModeResponse = SocketModeResponse
|
||||||
|
if self._web_client_factory is None:
|
||||||
|
self._web_client_factory = WebClient
|
||||||
|
|
||||||
bot_token = self.config.get("bot_token", "")
|
bot_token = self.config.get("bot_token", "")
|
||||||
app_token = self.config.get("app_token", "")
|
app_token = self.config.get("app_token", "")
|
||||||
|
|
||||||
|
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||||
|
if not bot_token:
|
||||||
|
logger.error("Slack HTTP Events mode requires bot_token")
|
||||||
|
return
|
||||||
|
await self._initialize_operator_web_client(str(bot_token))
|
||||||
|
self._loop = asyncio.get_event_loop()
|
||||||
|
self._running = True
|
||||||
|
self.bus.subscribe_outbound(self._on_outbound)
|
||||||
|
logger.info("Slack channel started in HTTP Events mode")
|
||||||
|
return
|
||||||
|
|
||||||
if not bot_token or not app_token:
|
if not bot_token or not app_token:
|
||||||
logger.error("Slack channel requires bot_token and app_token")
|
logger.error("Slack channel requires bot_token and app_token")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._web_client = WebClient(token=bot_token)
|
await self._initialize_operator_web_client(str(bot_token))
|
||||||
if self._bot_user_id is None:
|
|
||||||
try:
|
|
||||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
|
||||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
|
||||||
if user_id is None:
|
|
||||||
auth_get = getattr(auth_info, "get", None)
|
|
||||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
|
||||||
if isinstance(user_id, str) and user_id:
|
|
||||||
self._bot_user_id = user_id
|
|
||||||
except Exception:
|
|
||||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
|
||||||
self._socket_client = SocketModeClient(
|
self._socket_client = SocketModeClient(
|
||||||
app_token=app_token,
|
app_token=app_token,
|
||||||
web_client=self._web_client,
|
web_client=self._web_client,
|
||||||
@@ -124,7 +129,8 @@ class SlackChannel(Channel):
|
|||||||
logger.info("Slack channel stopped")
|
logger.info("Slack channel stopped")
|
||||||
|
|
||||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||||
if not self._web_client:
|
web_client = await self._get_web_client_for_message(msg)
|
||||||
|
if not web_client:
|
||||||
return
|
return
|
||||||
|
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
@@ -137,11 +143,12 @@ class SlackChannel(Channel):
|
|||||||
last_exc: Exception | None = None
|
last_exc: Exception | None = None
|
||||||
for attempt in range(_max_retries):
|
for attempt in range(_max_retries):
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
||||||
# Add a completion reaction to the thread root
|
# Add a completion reaction to the thread root
|
||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self._add_reaction,
|
self._add_reaction_with_client,
|
||||||
|
web_client,
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
msg.thread_ts,
|
msg.thread_ts,
|
||||||
"white_check_mark",
|
"white_check_mark",
|
||||||
@@ -165,7 +172,8 @@ class SlackChannel(Channel):
|
|||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
try:
|
try:
|
||||||
await asyncio.to_thread(
|
await asyncio.to_thread(
|
||||||
self._add_reaction,
|
self._add_reaction_with_client,
|
||||||
|
web_client,
|
||||||
msg.chat_id,
|
msg.chat_id,
|
||||||
msg.thread_ts,
|
msg.thread_ts,
|
||||||
"x",
|
"x",
|
||||||
@@ -177,7 +185,8 @@ class SlackChannel(Channel):
|
|||||||
raise last_exc
|
raise last_exc
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
if not self._web_client:
|
web_client = await self._get_web_client_for_message(msg)
|
||||||
|
if not web_client:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -190,7 +199,7 @@ class SlackChannel(Channel):
|
|||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
kwargs["thread_ts"] = msg.thread_ts
|
kwargs["thread_ts"] = msg.thread_ts
|
||||||
|
|
||||||
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
|
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
|
||||||
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
||||||
return True
|
return True
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -199,12 +208,38 @@ class SlackChannel(Channel):
|
|||||||
|
|
||||||
# -- internal ----------------------------------------------------------
|
# -- internal ----------------------------------------------------------
|
||||||
|
|
||||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
async def _initialize_operator_web_client(self, bot_token: str) -> None:
|
||||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
self._web_client = self._web_client_factory(token=bot_token)
|
||||||
if not self._web_client:
|
if self._bot_user_id is not None:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
self._web_client.reactions_add(
|
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||||
|
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||||
|
if user_id is None:
|
||||||
|
auth_get = getattr(auth_info, "get", None)
|
||||||
|
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||||
|
if isinstance(user_id, str) and user_id:
|
||||||
|
self._bot_user_id = user_id
|
||||||
|
except Exception:
|
||||||
|
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||||
|
|
||||||
|
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
||||||
|
if msg.connection_id and self._connection_repo is not None:
|
||||||
|
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
||||||
|
access_token = credentials.get("access_token") if credentials else None
|
||||||
|
if not access_token:
|
||||||
|
return self._web_client
|
||||||
|
if self._web_client_factory is None:
|
||||||
|
from slack_sdk import WebClient
|
||||||
|
|
||||||
|
self._web_client_factory = WebClient
|
||||||
|
return self._web_client_factory(token=access_token)
|
||||||
|
return self._web_client
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||||
|
try:
|
||||||
|
web_client.reactions_add(
|
||||||
channel=channel_id,
|
channel=channel_id,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
name=emoji,
|
name=emoji,
|
||||||
@@ -213,6 +248,12 @@ class SlackChannel(Channel):
|
|||||||
if "already_reacted" not in str(exc):
|
if "already_reacted" not in str(exc):
|
||||||
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
||||||
|
|
||||||
|
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||||
|
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||||
|
if not self._web_client:
|
||||||
|
return
|
||||||
|
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
|
||||||
|
|
||||||
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
||||||
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
||||||
if not self._web_client:
|
if not self._web_client:
|
||||||
@@ -249,12 +290,15 @@ class SlackChannel(Channel):
|
|||||||
|
|
||||||
# Handle message events (DM or @mention)
|
# Handle message events (DM or @mention)
|
||||||
if etype in ("message", "app_mention"):
|
if etype in ("message", "app_mention"):
|
||||||
self._handle_message_event(event)
|
self._handle_message_event(
|
||||||
|
event,
|
||||||
|
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Error processing Slack event")
|
logger.exception("Error processing Slack event")
|
||||||
|
|
||||||
def _handle_message_event(self, event: dict) -> None:
|
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
|
||||||
# Ignore bot messages
|
# Ignore bot messages
|
||||||
if event.get("bot_id") or event.get("subtype"):
|
if event.get("bot_id") or event.get("subtype"):
|
||||||
return
|
return
|
||||||
@@ -272,6 +316,19 @@ class SlackChannel(Channel):
|
|||||||
if not text:
|
if not text:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code:
|
||||||
|
if self._loop and self._loop.is_running():
|
||||||
|
asyncio.run_coroutine_threadsafe(
|
||||||
|
self._bind_connection_from_connect_code(
|
||||||
|
event=event,
|
||||||
|
team_id=str(team_id or event.get("team") or ""),
|
||||||
|
code=connect_code,
|
||||||
|
),
|
||||||
|
self._loop,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
channel_id = event.get("channel", "")
|
channel_id = event.get("channel", "")
|
||||||
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
||||||
|
|
||||||
@@ -297,4 +354,61 @@ class SlackChannel(Channel):
|
|||||||
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
||||||
# Send "running" reply first (fire-and-forget from SDK thread)
|
# Send "running" reply first (fire-and-forget from SDK thread)
|
||||||
self._send_running_reply(channel_id, thread_ts)
|
self._send_running_reply(channel_id, thread_ts)
|
||||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
if self._connection_repo is None:
|
||||||
|
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||||
|
else:
|
||||||
|
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
|
||||||
|
|
||||||
|
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
|
||||||
|
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
|
||||||
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
||||||
|
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="slack",
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
channel_id = str(event.get("channel") or "")
|
||||||
|
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
|
||||||
|
if state is None:
|
||||||
|
self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
|
||||||
|
return True
|
||||||
|
|
||||||
|
user_id = str(event.get("user") or "")
|
||||||
|
if not user_id or not team_id:
|
||||||
|
self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="slack",
|
||||||
|
external_account_id=user_id,
|
||||||
|
workspace_id=team_id,
|
||||||
|
metadata={
|
||||||
|
"team_id": team_id,
|
||||||
|
"channel_id": channel_id,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
|
||||||
|
if not self._web_client or not channel_id:
|
||||||
|
return
|
||||||
|
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
|
||||||
|
if thread_ts:
|
||||||
|
kwargs["thread_ts"] = thread_ts
|
||||||
|
try:
|
||||||
|
self._web_client.chat_postMessage(**kwargs)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import threading
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -35,6 +36,7 @@ class TelegramChannel(Channel):
|
|||||||
pass
|
pass
|
||||||
# chat_id -> last sent message_id for threaded replies
|
# chat_id -> last sent message_id for threaded replies
|
||||||
self._last_bot_message: dict[str, int] = {}
|
self._last_bot_message: dict[str, int] = {}
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
if self._running:
|
if self._running:
|
||||||
@@ -176,6 +178,26 @@ class TelegramChannel(Channel):
|
|||||||
logger.exception("[Telegram] failed to send file: %s", attachment.filename)
|
logger.exception("[Telegram] failed to send file: %s", attachment.filename)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
async def process_webhook_update(self, payload: dict[str, Any]) -> bool:
|
||||||
|
if not self._application:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
from telegram import Update
|
||||||
|
except ImportError:
|
||||||
|
logger.error("python-telegram-bot is not installed. Install it with: uv add python-telegram-bot")
|
||||||
|
return False
|
||||||
|
|
||||||
|
update = Update.de_json(payload, self._application.bot)
|
||||||
|
if update is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self._tg_loop and self._tg_loop.is_running():
|
||||||
|
future = asyncio.run_coroutine_threadsafe(self._application.process_update(update), self._tg_loop)
|
||||||
|
await asyncio.wrap_future(future)
|
||||||
|
else:
|
||||||
|
await self._application.process_update(update)
|
||||||
|
return True
|
||||||
|
|
||||||
# -- helpers -----------------------------------------------------------
|
# -- helpers -----------------------------------------------------------
|
||||||
|
|
||||||
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
||||||
@@ -233,6 +255,54 @@ class TelegramChannel(Channel):
|
|||||||
return True
|
return True
|
||||||
return user_id in self._allowed_users
|
return user_id in self._allowed_users
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _telegram_display_name(user) -> str:
|
||||||
|
full_name = getattr(user, "full_name", None)
|
||||||
|
if isinstance(full_name, str) and full_name:
|
||||||
|
return full_name
|
||||||
|
username = getattr(user, "username", None)
|
||||||
|
if isinstance(username, str) and username:
|
||||||
|
return username
|
||||||
|
return str(getattr(user, "id", ""))
|
||||||
|
|
||||||
|
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
|
||||||
|
if self._connection_repo is None or not state_token:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
|
||||||
|
if state is None:
|
||||||
|
await update.message.reply_text("Telegram connection link is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
owner_user_id = state["owner_user_id"]
|
||||||
|
user_id = str(update.effective_user.id)
|
||||||
|
chat_id = str(update.effective_chat.id)
|
||||||
|
connection = await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id=user_id,
|
||||||
|
external_account_name=self._telegram_display_name(update.effective_user),
|
||||||
|
workspace_id=chat_id,
|
||||||
|
workspace_name=None,
|
||||||
|
metadata={
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"chat_type": update.effective_chat.type,
|
||||||
|
"telegram_username": getattr(update.effective_user, "username", None),
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
|
||||||
|
await update.message.reply_text("Telegram connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="telegram",
|
||||||
|
workspace_id=inbound.chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
def _get_bot_username(self, context) -> str | None:
|
def _get_bot_username(self, context) -> str | None:
|
||||||
bot = getattr(context, "bot", None)
|
bot = getattr(context, "bot", None)
|
||||||
username = getattr(bot, "username", None)
|
username = getattr(bot, "username", None)
|
||||||
@@ -264,6 +334,11 @@ class TelegramChannel(Channel):
|
|||||||
"""Handle /start command."""
|
"""Handle /start command."""
|
||||||
if not self._check_user(update.effective_user.id):
|
if not self._check_user(update.effective_user.id):
|
||||||
return
|
return
|
||||||
|
args = getattr(context, "args", []) if context is not None else []
|
||||||
|
if args:
|
||||||
|
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
||||||
|
if handled:
|
||||||
|
return
|
||||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||||
|
|
||||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||||
@@ -299,6 +374,7 @@ class TelegramChannel(Channel):
|
|||||||
thread_ts=msg_id,
|
thread_ts=msg_id,
|
||||||
)
|
)
|
||||||
inbound.topic_id = topic_id
|
inbound.topic_id = topic_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
if self._main_loop and self._main_loop.is_running():
|
||||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||||
@@ -341,6 +417,7 @@ class TelegramChannel(Channel):
|
|||||||
thread_ts=msg_id,
|
thread_ts=msg_id,
|
||||||
)
|
)
|
||||||
inbound.topic_id = topic_id
|
inbound.topic_id = topic_id
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
|
|
||||||
if self._main_loop and self._main_loop.is_running():
|
if self._main_loop and self._main_loop.is_running():
|
||||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||||
|
|||||||
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
|
|||||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
|
|||||||
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
||||||
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
||||||
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
self._load_state()
|
self._load_state()
|
||||||
|
|
||||||
async def start(self) -> None:
|
async def start(self) -> None:
|
||||||
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
|
|||||||
if thread_ts:
|
if thread_ts:
|
||||||
self._context_tokens_by_thread[thread_ts] = context_token
|
self._context_tokens_by_thread[thread_ts] = context_token
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
handled = await self._bind_connection_from_connect_code(
|
||||||
|
chat_id=chat_id,
|
||||||
|
context_token=context_token,
|
||||||
|
code=connect_code,
|
||||||
|
)
|
||||||
|
if handled:
|
||||||
|
return
|
||||||
|
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=chat_id,
|
chat_id=chat_id,
|
||||||
user_id=chat_id,
|
user_id=chat_id,
|
||||||
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
inbound.topic_id = None
|
inbound.topic_id = None
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="wechat",
|
||||||
|
workspace_id=inbound.chat_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not chat_id:
|
||||||
|
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="wechat",
|
||||||
|
external_account_id=chat_id,
|
||||||
|
workspace_id=chat_id,
|
||||||
|
metadata={
|
||||||
|
"context_token": context_token,
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
|
||||||
|
if not context_token:
|
||||||
|
return
|
||||||
|
await self._send_text_message(
|
||||||
|
chat_id=chat_id,
|
||||||
|
context_token=context_token,
|
||||||
|
text=text,
|
||||||
|
client_id_prefix="deerflow-connect",
|
||||||
|
max_retries=1,
|
||||||
|
)
|
||||||
|
|
||||||
async def _ensure_authenticated(self) -> bool:
|
async def _ensure_authenticated(self) -> bool:
|
||||||
async with self._auth_lock:
|
async with self._auth_lock:
|
||||||
if self._bot_token:
|
if self._bot_token:
|
||||||
|
|||||||
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
|
|||||||
from typing import Any, cast
|
from typing import Any, cast
|
||||||
|
|
||||||
from app.channels.base import Channel
|
from app.channels.base import Channel
|
||||||
from app.channels.commands import is_known_channel_command
|
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||||
|
from app.channels.connection_identity import attach_connection_identity
|
||||||
from app.channels.message_bus import (
|
from app.channels.message_bus import (
|
||||||
|
InboundMessage,
|
||||||
InboundMessageType,
|
InboundMessageType,
|
||||||
MessageBus,
|
MessageBus,
|
||||||
OutboundMessage,
|
OutboundMessage,
|
||||||
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
|
|||||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||||
self._ws_stream_ids: dict[str, str] = {}
|
self._ws_stream_ids: dict[str, str] = {}
|
||||||
self._working_message = "Working on it..."
|
self._working_message = "Working on it..."
|
||||||
|
self._connection_repo = config.get("connection_repo")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def supports_streaming(self) -> bool:
|
def supports_streaming(self) -> bool:
|
||||||
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
|
|||||||
|
|
||||||
user_id = (body.get("from") or {}).get("userid")
|
user_id = (body.get("from") or {}).get("userid")
|
||||||
|
|
||||||
|
connect_code = extract_connect_code(text)
|
||||||
|
if connect_code and self._connection_repo is not None:
|
||||||
|
handled = await self._bind_connection_from_connect_code(
|
||||||
|
frame=frame,
|
||||||
|
user_id=str(user_id or ""),
|
||||||
|
code=connect_code,
|
||||||
|
)
|
||||||
|
if handled:
|
||||||
|
return
|
||||||
|
|
||||||
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||||
inbound = self._make_inbound(
|
inbound = self._make_inbound(
|
||||||
chat_id=user_id, # keep user's conversation in memory
|
chat_id=user_id, # keep user's conversation in memory
|
||||||
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
inbound = await self._attach_connection_identity(inbound)
|
||||||
await self.bus.publish_inbound(inbound)
|
await self.bus.publish_inbound(inbound)
|
||||||
|
|
||||||
|
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||||
|
return await attach_connection_identity(
|
||||||
|
inbound,
|
||||||
|
repo=self._connection_repo,
|
||||||
|
provider="wecom",
|
||||||
|
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
|
||||||
|
fallback_without_workspace=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
|
||||||
|
if self._connection_repo is None or not code:
|
||||||
|
return False
|
||||||
|
|
||||||
|
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
|
||||||
|
if state is None:
|
||||||
|
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if not user_id:
|
||||||
|
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
body = frame.get("body", {}) or {}
|
||||||
|
workspace_id = str(body.get("aibotid") or "") or None
|
||||||
|
await self._connection_repo.upsert_connection(
|
||||||
|
owner_user_id=state["owner_user_id"],
|
||||||
|
provider="wecom",
|
||||||
|
external_account_id=user_id,
|
||||||
|
workspace_id=workspace_id,
|
||||||
|
metadata={
|
||||||
|
"aibotid": workspace_id,
|
||||||
|
"chattype": body.get("chattype"),
|
||||||
|
},
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
|
||||||
|
if not self._ws_client:
|
||||||
|
return
|
||||||
|
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
|
||||||
|
|
||||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||||
if not self._ws_client:
|
if not self._ws_client:
|
||||||
return
|
return
|
||||||
|
|||||||
+26
-14
@@ -6,6 +6,7 @@ from contextlib import asynccontextmanager
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||||
from app.gateway.auth_middleware import AuthMiddleware
|
from app.gateway.auth_middleware import AuthMiddleware
|
||||||
from app.gateway.config import get_gateway_config
|
from app.gateway.config import get_gateway_config
|
||||||
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
from app.gateway.csrf_middleware import CSRFMiddleware, get_configured_cors_origins
|
||||||
@@ -15,6 +16,7 @@ from app.gateway.routers import (
|
|||||||
artifacts,
|
artifacts,
|
||||||
assistants_compat,
|
assistants_compat,
|
||||||
auth,
|
auth,
|
||||||
|
channel_connections,
|
||||||
channels,
|
channels,
|
||||||
feedback,
|
feedback,
|
||||||
mcp,
|
mcp,
|
||||||
@@ -172,6 +174,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
startup_config = get_app_config()
|
startup_config = get_app_config()
|
||||||
apply_logging_level(startup_config.log_level)
|
apply_logging_level(startup_config.log_level)
|
||||||
logger.info("Configuration loaded successfully")
|
logger.info("Configuration loaded successfully")
|
||||||
|
warn_if_auth_disabled_enabled()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
error_msg = f"Failed to load configuration during gateway startup: {e}"
|
||||||
logger.exception(error_msg)
|
logger.exception(error_msg)
|
||||||
@@ -182,21 +185,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
||||||
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
||||||
# that may be unreachable in restricted networks — see issue #3402).
|
# that may be unreachable in restricted networks — see issue #3402).
|
||||||
try:
|
# When memory.token_counting is "char", token counting never touches
|
||||||
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
# tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
|
||||||
|
# network-restricted deployments — see issue #3429).
|
||||||
|
if startup_config.memory.token_counting == "char":
|
||||||
|
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||||
|
|
||||||
warmed = await asyncio.wait_for(
|
warmed = await asyncio.wait_for(
|
||||||
asyncio.to_thread(warm_tiktoken_cache),
|
asyncio.to_thread(warm_tiktoken_cache),
|
||||||
timeout=5,
|
timeout=5,
|
||||||
)
|
)
|
||||||
if warmed:
|
if warmed:
|
||||||
logger.info("tiktoken encoding cache warmed successfully")
|
logger.info("tiktoken encoding cache warmed successfully")
|
||||||
else:
|
else:
|
||||||
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
|
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
|
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||||
|
|
||||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||||
async with langgraph_runtime(app, startup_config):
|
async with langgraph_runtime(app, startup_config):
|
||||||
@@ -376,6 +385,9 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
|||||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||||
app.include_router(suggestions.router)
|
app.include_router(suggestions.router)
|
||||||
|
|
||||||
|
# User-facing IM channel connection API is mounted at /api/channels
|
||||||
|
app.include_router(channel_connections.router)
|
||||||
|
|
||||||
# Channels API is mounted at /api/channels
|
# Channels API is mounted at /api/channels
|
||||||
app.include_router(channels.router)
|
app.include_router(channels.router)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""Shared helpers for local/E2E auth-disabled mode."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
|
AUTH_DISABLED_ENV_VAR = "DEER_FLOW_AUTH_DISABLED"
|
||||||
|
AUTH_DISABLED_USER_ID = DEFAULT_USER_ID
|
||||||
|
AUTH_DISABLED_USER_EMAIL = "default@test.local"
|
||||||
|
|
||||||
|
AUTH_SOURCE_SESSION = "session"
|
||||||
|
AUTH_SOURCE_INTERNAL = "internal"
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED = "auth_disabled"
|
||||||
|
|
||||||
|
_PRODUCTION_ENV_VARS: tuple[str, ...] = ("DEER_FLOW_ENV", "ENVIRONMENT")
|
||||||
|
_PRODUCTION_ENV_VALUES: frozenset[str] = frozenset({"prod", "production"})
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def is_explicit_production_environment() -> bool:
|
||||||
|
return any(os.environ.get(name, "").strip().lower() in _PRODUCTION_ENV_VALUES for name in _PRODUCTION_ENV_VARS)
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_disabled_requested() -> bool:
|
||||||
|
return os.environ.get(AUTH_DISABLED_ENV_VAR) == "1"
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_disabled() -> bool:
|
||||||
|
return is_auth_disabled_requested() and not is_explicit_production_environment()
|
||||||
|
|
||||||
|
|
||||||
|
def warn_if_auth_disabled_enabled() -> None:
|
||||||
|
if not is_auth_disabled():
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
"%s=1 is active: authentication is bypassed and anonymous requests run as synthetic admin user %r. Do not enable this in shared or production deployments.",
|
||||||
|
AUTH_DISABLED_ENV_VAR,
|
||||||
|
AUTH_DISABLED_USER_ID,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_disabled_user():
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=AUTH_DISABLED_USER_ID,
|
||||||
|
email=AUTH_DISABLED_USER_EMAIL,
|
||||||
|
password_hash=None,
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=False,
|
||||||
|
token_version=0,
|
||||||
|
)
|
||||||
@@ -17,6 +17,13 @@ from starlette.responses import JSONResponse
|
|||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||||
|
from app.gateway.auth_disabled import (
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED,
|
||||||
|
AUTH_SOURCE_INTERNAL,
|
||||||
|
AUTH_SOURCE_SESSION,
|
||||||
|
get_auth_disabled_user,
|
||||||
|
is_auth_disabled,
|
||||||
|
)
|
||||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
||||||
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
from app.gateway.internal_auth import INTERNAL_AUTH_HEADER_NAME, get_internal_user, is_valid_internal_auth_token
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
@@ -80,8 +87,38 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
if is_valid_internal_auth_token(request.headers.get(INTERNAL_AUTH_HEADER_NAME)):
|
||||||
internal_user = get_internal_user()
|
internal_user = get_internal_user()
|
||||||
|
|
||||||
|
auth_source = AUTH_SOURCE_SESSION
|
||||||
|
access_token = request.cookies.get("access_token")
|
||||||
|
|
||||||
# Non-public path: require session cookie
|
# Non-public path: require session cookie
|
||||||
if internal_user is None and not request.cookies.get("access_token"):
|
if internal_user is not None:
|
||||||
|
user = internal_user
|
||||||
|
auth_source = AUTH_SOURCE_INTERNAL
|
||||||
|
elif access_token:
|
||||||
|
# Strict JWT validation: reject junk/expired tokens with 401
|
||||||
|
# right here instead of silently passing through. This closes
|
||||||
|
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
||||||
|
# without this, non-isolation routes like /api/models would
|
||||||
|
# accept any cookie-shaped string as authentication.
|
||||||
|
#
|
||||||
|
# We call the *strict* resolver so that fine-grained error
|
||||||
|
# codes (token_expired, token_invalid, user_not_found, …)
|
||||||
|
# propagate from AuthErrorCode, not get flattened into one
|
||||||
|
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
||||||
|
# bubble up, so we catch and render it as JSONResponse here.
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
try:
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if not is_auth_disabled():
|
||||||
|
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
|
user = get_auth_disabled_user()
|
||||||
|
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
elif is_auth_disabled():
|
||||||
|
user = get_auth_disabled_user()
|
||||||
|
auth_source = AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
else:
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=401,
|
status_code=401,
|
||||||
content={
|
content={
|
||||||
@@ -92,32 +129,12 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Strict JWT validation: reject junk/expired tokens with 401
|
|
||||||
# right here instead of silently passing through. This closes
|
|
||||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
|
||||||
# without this, non-isolation routes like /api/models would
|
|
||||||
# accept any cookie-shaped string as authentication.
|
|
||||||
#
|
|
||||||
# We call the *strict* resolver so that fine-grained error
|
|
||||||
# codes (token_expired, token_invalid, user_not_found, …)
|
|
||||||
# propagate from AuthErrorCode, not get flattened into one
|
|
||||||
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
|
||||||
# bubble up, so we catch and render it as JSONResponse here.
|
|
||||||
from app.gateway.deps import get_current_user_from_request
|
|
||||||
|
|
||||||
if internal_user is not None:
|
|
||||||
user = internal_user
|
|
||||||
else:
|
|
||||||
try:
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
except HTTPException as exc:
|
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
|
||||||
|
|
||||||
# Stamp both request.state.user (for the contextvar pattern)
|
# Stamp both request.state.user (for the contextvar pattern)
|
||||||
# and request.state.auth (so @require_permission's "auth is
|
# and request.state.auth (so @require_permission's "auth is
|
||||||
# None" branch short-circuits instead of running the entire
|
# None" branch short-circuits instead of running the entire
|
||||||
# JWT-decode + DB-lookup pipeline a second time per request).
|
# JWT-decode + DB-lookup pipeline a second time per request).
|
||||||
request.state.user = user
|
request.state.user = user
|
||||||
|
request.state.auth_source = auth_source
|
||||||
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||||
token = set_current_user(user)
|
token = set_current_user(user)
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -276,6 +276,11 @@ def require_permission(
|
|||||||
# strict-deny rather than strict-allow — only an *existing*
|
# strict-deny rather than strict-allow — only an *existing*
|
||||||
# row with a *different* user_id triggers 404.
|
# row with a *different* user_id triggers 404.
|
||||||
if owner_check:
|
if owner_check:
|
||||||
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||||
|
|
||||||
|
if getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
thread_id = kwargs.get("thread_id")
|
thread_id = kwargs.get("thread_id")
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
|||||||
from starlette.responses import JSONResponse
|
from starlette.responses import JSONResponse
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.gateway.auth_disabled import is_auth_disabled
|
||||||
|
|
||||||
CSRF_COOKIE_NAME = "csrf_token"
|
CSRF_COOKIE_NAME = "csrf_token"
|
||||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||||
@@ -38,6 +40,9 @@ def should_check_csrf(request: Request) -> bool:
|
|||||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return False
|
||||||
|
|
||||||
path = request.url.path.rstrip("/")
|
path = request.url.path.rstrip("/")
|
||||||
# Exempt /api/v1/auth/me endpoint
|
# Exempt /api/v1/auth/me endpoint
|
||||||
if path == "/api/v1/auth/me":
|
if path == "/api/v1/auth/me":
|
||||||
|
|||||||
@@ -331,6 +331,17 @@ async def get_current_user_from_request(request: Request):
|
|||||||
|
|
||||||
Raises HTTPException 401 if not authenticated.
|
Raises HTTPException 401 if not authenticated.
|
||||||
"""
|
"""
|
||||||
|
state = getattr(request, "state", None)
|
||||||
|
state_user = getattr(state, "user", None)
|
||||||
|
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED, AUTH_SOURCE_INTERNAL, AUTH_SOURCE_SESSION
|
||||||
|
|
||||||
|
if state_user is not None and getattr(state, "auth_source", None) in {
|
||||||
|
AUTH_SOURCE_SESSION,
|
||||||
|
AUTH_SOURCE_AUTH_DISABLED,
|
||||||
|
AUTH_SOURCE_INTERNAL,
|
||||||
|
}:
|
||||||
|
return state_user
|
||||||
|
|
||||||
from app.gateway.auth import decode_token
|
from app.gateway.auth import decode_token
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||||
|
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
|
||||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||||
INTERNAL_SYSTEM_ROLE = "internal"
|
INTERNAL_SYSTEM_ROLE = "internal"
|
||||||
|
|
||||||
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
|
|||||||
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
||||||
|
|
||||||
|
|
||||||
def create_internal_auth_headers() -> dict[str, str]:
|
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
|
||||||
"""Return headers that authenticate trusted Gateway internal calls."""
|
"""Return headers that authenticate trusted Gateway internal calls."""
|
||||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||||
|
if owner_user_id:
|
||||||
|
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||||
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
|||||||
def get_internal_user():
|
def get_internal_user():
|
||||||
"""Return the synthetic user used for trusted internal channel calls."""
|
"""Return the synthetic user used for trusted internal channel calls."""
|
||||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||||
|
|
||||||
|
|
||||||
|
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
|
||||||
|
"""Return the owner override for a trusted internal request, if present.
|
||||||
|
|
||||||
|
The header is ignored for normal browser/API callers. It is only honored
|
||||||
|
after ``AuthMiddleware`` has validated the internal auth token and stamped
|
||||||
|
the synthetic internal user onto ``request.state.user``.
|
||||||
|
"""
|
||||||
|
user = getattr(getattr(request, "state", None), "user", None)
|
||||||
|
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
|
||||||
|
if not owner_user_id:
|
||||||
|
return None
|
||||||
|
owner_user_id = owner_user_id.strip()
|
||||||
|
return owner_user_id or None
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ from langgraph_sdk import Auth
|
|||||||
|
|
||||||
from app.gateway.auth.errors import TokenError
|
from app.gateway.auth.errors import TokenError
|
||||||
from app.gateway.auth.jwt import decode_token
|
from app.gateway.auth.jwt import decode_token
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID, is_auth_disabled
|
||||||
from app.gateway.deps import get_local_provider
|
from app.gateway.deps import get_local_provider
|
||||||
|
|
||||||
auth = Auth()
|
auth = Auth()
|
||||||
@@ -38,6 +39,9 @@ def _check_csrf(request) -> None:
|
|||||||
if method.upper() not in _CSRF_METHODS:
|
if method.upper() not in _CSRF_METHODS:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return
|
||||||
|
|
||||||
cookie_token = request.cookies.get("csrf_token")
|
cookie_token = request.cookies.get("csrf_token")
|
||||||
header_token = request.headers.get("x-csrf-token")
|
header_token = request.headers.get("x-csrf-token")
|
||||||
|
|
||||||
@@ -66,6 +70,9 @@ async def authenticate(request):
|
|||||||
# are rejected early, even if the cookie carries a valid JWT.
|
# are rejected early, even if the cookie carries a valid JWT.
|
||||||
_check_csrf(request)
|
_check_csrf(request)
|
||||||
|
|
||||||
|
if is_auth_disabled():
|
||||||
|
return AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
token = request.cookies.get("access_token")
|
token = request.cookies.get("access_token")
|
||||||
if not token:
|
if not token:
|
||||||
raise Auth.exceptions.HTTPException(
|
raise Auth.exceptions.HTTPException(
|
||||||
|
|||||||
@@ -341,9 +341,19 @@ async def change_password(request: Request, response: Response, body: ChangePass
|
|||||||
- Re-issues session cookie with new token_version
|
- Re-issues session cookie with new token_version
|
||||||
"""
|
"""
|
||||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||||
|
from app.gateway.auth_disabled import AUTH_SOURCE_AUTH_DISABLED
|
||||||
|
|
||||||
user = await get_current_user_from_request(request)
|
user = await get_current_user_from_request(request)
|
||||||
|
|
||||||
|
if getattr(request.state, "auth_source", None) == AUTH_SOURCE_AUTH_DISABLED:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=AuthErrorResponse(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="Password changes are not available when DEER_FLOW_AUTH_DISABLED=1.",
|
||||||
|
).model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
if user.password_hash is None:
|
if user.password_hash is None:
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,600 @@
|
|||||||
|
"""Browser-facing APIs for user-owned IM channel bindings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException, Request, Response
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.channels.runtime_config_store import (
|
||||||
|
ChannelRuntimeConfigStore,
|
||||||
|
apply_runtime_connection_config,
|
||||||
|
merge_runtime_channel_configs,
|
||||||
|
)
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||||
|
from deerflow.persistence.engine import get_session_factory
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_STATE_TTL_SECONDS = 600
|
||||||
|
_MASKED_CREDENTIAL_VALUE = "********"
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelCredentialFieldResponse(BaseModel):
|
||||||
|
name: str
|
||||||
|
label: str
|
||||||
|
type: str = "text"
|
||||||
|
required: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelProviderResponse(BaseModel):
|
||||||
|
provider: str
|
||||||
|
display_name: str
|
||||||
|
enabled: bool
|
||||||
|
configured: bool
|
||||||
|
connectable: bool
|
||||||
|
unavailable_reason: str | None = None
|
||||||
|
auth_mode: str
|
||||||
|
connection_status: str
|
||||||
|
credential_fields: list[ChannelCredentialFieldResponse] = Field(default_factory=list)
|
||||||
|
credential_values: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelProvidersResponse(BaseModel):
|
||||||
|
enabled: bool
|
||||||
|
providers: list[ChannelProviderResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
provider: str
|
||||||
|
status: str
|
||||||
|
external_account_id: str | None = None
|
||||||
|
external_account_name: str | None = None
|
||||||
|
workspace_id: str | None = None
|
||||||
|
workspace_name: str | None = None
|
||||||
|
scopes: list[str] = Field(default_factory=list)
|
||||||
|
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionsResponse(BaseModel):
|
||||||
|
connections: list[ChannelConnectionResponse]
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectResponse(BaseModel):
|
||||||
|
provider: str
|
||||||
|
mode: str
|
||||||
|
url: str | None = None
|
||||||
|
code: str
|
||||||
|
instruction: str
|
||||||
|
expires_in: int
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelRuntimeConfigRequest(BaseModel):
|
||||||
|
values: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
_PROVIDER_META: dict[str, dict[str, str]] = {
|
||||||
|
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
|
||||||
|
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
|
||||||
|
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
|
||||||
|
"feishu": {"display_name": "Feishu", "auth_mode": "binding_code"},
|
||||||
|
"dingtalk": {"display_name": "DingTalk", "auth_mode": "binding_code"},
|
||||||
|
"wechat": {"display_name": "WeChat", "auth_mode": "binding_code"},
|
||||||
|
"wecom": {"display_name": "WeCom", "auth_mode": "binding_code"},
|
||||||
|
}
|
||||||
|
|
||||||
|
_CREDENTIAL_FIELDS: dict[str, tuple[dict[str, str], ...]] = {
|
||||||
|
"telegram": (
|
||||||
|
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||||
|
{"name": "bot_username", "label": "Bot username", "type": "text"},
|
||||||
|
),
|
||||||
|
"slack": (
|
||||||
|
{"name": "bot_token", "label": "Bot token", "type": "password"},
|
||||||
|
{"name": "app_token", "label": "App token", "type": "password"},
|
||||||
|
),
|
||||||
|
"discord": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||||
|
"feishu": (
|
||||||
|
{"name": "app_id", "label": "App ID", "type": "text"},
|
||||||
|
{"name": "app_secret", "label": "App secret", "type": "password"},
|
||||||
|
),
|
||||||
|
"dingtalk": (
|
||||||
|
{"name": "client_id", "label": "Client ID", "type": "text"},
|
||||||
|
{"name": "client_secret", "label": "Client secret", "type": "password"},
|
||||||
|
),
|
||||||
|
"wechat": ({"name": "bot_token", "label": "Bot token", "type": "password"},),
|
||||||
|
"wecom": (
|
||||||
|
{"name": "bot_id", "label": "Bot ID", "type": "text"},
|
||||||
|
{"name": "bot_secret", "label": "Bot secret", "type": "password"},
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
|
||||||
|
"telegram": ("bot_token",),
|
||||||
|
"slack": ("bot_token", "app_token"),
|
||||||
|
"discord": ("bot_token",),
|
||||||
|
"feishu": ("app_id", "app_secret"),
|
||||||
|
"dingtalk": ("client_id", "client_secret"),
|
||||||
|
"wechat": ("bot_token",),
|
||||||
|
"wecom": ("bot_id", "bot_secret"),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id(request: Request) -> str:
|
||||||
|
user = getattr(request.state, "user", None)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
return str(user.id)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_app_config():
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
|
|
||||||
|
return get_app_config()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_runtime_config_store(request: Request) -> ChannelRuntimeConfigStore:
|
||||||
|
store = getattr(request.app.state, "channel_runtime_config_store", None)
|
||||||
|
if isinstance(store, ChannelRuntimeConfigStore):
|
||||||
|
return store
|
||||||
|
store = ChannelRuntimeConfigStore()
|
||||||
|
request.app.state.channel_runtime_config_store = store
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||||
|
config = getattr(request.app.state, "channel_connections_config", None)
|
||||||
|
if not isinstance(config, ChannelConnectionsConfig):
|
||||||
|
config = _get_app_config().channel_connections
|
||||||
|
config = apply_runtime_connection_config(config, store=_get_runtime_config_store(request))
|
||||||
|
request.app.state.channel_connections_config = config
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||||
|
state_config = getattr(request.app.state, "channels_config", None)
|
||||||
|
if isinstance(state_config, dict):
|
||||||
|
return state_config
|
||||||
|
|
||||||
|
result = _load_channels_config(request, _get_channel_connections_config(request))
|
||||||
|
request.app.state.channels_config = result
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _load_channels_config(request: Request, config: ChannelConnectionsConfig) -> dict[str, Any]:
|
||||||
|
app_config = _get_app_config()
|
||||||
|
extra = app_config.model_extra or {}
|
||||||
|
channels_config = extra.get("channels")
|
||||||
|
result = dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||||
|
merge_runtime_channel_configs(
|
||||||
|
result,
|
||||||
|
config,
|
||||||
|
store=_get_runtime_config_store(request),
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||||
|
repo = getattr(request.app.state, "channel_connection_repo", None)
|
||||||
|
if isinstance(repo, ChannelConnectionRepository):
|
||||||
|
return repo
|
||||||
|
|
||||||
|
sf = get_session_factory()
|
||||||
|
if sf is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
|
||||||
|
|
||||||
|
repo = ChannelConnectionRepository(sf)
|
||||||
|
request.app.state.channel_connection_repo = repo
|
||||||
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_config(config: ChannelConnectionsConfig, provider: str):
|
||||||
|
provider_config = getattr(config, provider, None)
|
||||||
|
if provider_config is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
return provider_config
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||||
|
return False
|
||||||
|
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_unavailable_reason(provider: str) -> str:
|
||||||
|
meta = _PROVIDER_META.get(provider)
|
||||||
|
display_name = meta["display_name"] if meta else provider
|
||||||
|
return f"Enter the required {display_name} credentials to connect this channel."
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_not_running_reason(provider: str) -> str:
|
||||||
|
meta = _PROVIDER_META.get(provider)
|
||||||
|
display_name = meta["display_name"] if meta else provider
|
||||||
|
return f"{display_name} channel is configured but is not running. Check the credentials and save this channel again."
|
||||||
|
|
||||||
|
|
||||||
|
def _runtime_channel_running(provider: str) -> bool | None:
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to inspect channel service status", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
status = service.get_status()
|
||||||
|
except Exception:
|
||||||
|
logger.debug("Unable to read channel service status", exc_info=True)
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not status.get("service_running"):
|
||||||
|
return False
|
||||||
|
channel_status = status.get("channels", {}).get(provider)
|
||||||
|
if not isinstance(channel_status, dict):
|
||||||
|
return None
|
||||||
|
return bool(channel_status.get("running"))
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_unavailable_reason(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
provider: str,
|
||||||
|
) -> str | None:
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if not provider_config.enabled:
|
||||||
|
return None
|
||||||
|
if not provider_config.configured:
|
||||||
|
return _runtime_unavailable_reason(provider)
|
||||||
|
if not _runtime_channel_configured(provider, channels_config):
|
||||||
|
return _runtime_unavailable_reason(provider)
|
||||||
|
if _runtime_channel_running(provider) is False:
|
||||||
|
return _runtime_not_running_reason(provider)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_status(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
provider: str,
|
||||||
|
) -> tuple[dict[str, bool], str | None]:
|
||||||
|
declared = config.provider_status(provider)
|
||||||
|
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
|
||||||
|
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
|
||||||
|
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
|
||||||
|
|
||||||
|
|
||||||
|
def _new_binding_code() -> str:
|
||||||
|
return secrets.token_urlsafe(16)
|
||||||
|
|
||||||
|
|
||||||
|
async def _create_state(
|
||||||
|
repo: ChannelConnectionRepository,
|
||||||
|
*,
|
||||||
|
owner_user_id: str,
|
||||||
|
provider: str,
|
||||||
|
) -> str:
|
||||||
|
state = _new_binding_code()
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider=provider,
|
||||||
|
state=state,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||||
|
)
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
def _connect_instruction(provider: str, code: str) -> str:
|
||||||
|
if provider == "telegram":
|
||||||
|
return f"Send /start {code} to the DeerFlow Telegram bot."
|
||||||
|
meta = _PROVIDER_META.get(provider)
|
||||||
|
if meta is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
return f"Send /connect {code} to the DeerFlow {meta['display_name']} bot."
|
||||||
|
|
||||||
|
|
||||||
|
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
|
||||||
|
if provider == "telegram":
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
return f"https://t.me/{provider_config.bot_username}?start={code}"
|
||||||
|
if _PROVIDER_META.get(provider, {}).get("auth_mode") == "binding_code":
|
||||||
|
return None
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
|
||||||
|
|
||||||
|
def _connection_updated_at(connection: dict[str, Any]) -> datetime:
|
||||||
|
value = connection.get("updated_at")
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
return value if value.tzinfo is not None else value.replace(tzinfo=UTC)
|
||||||
|
if isinstance(value, str) and value:
|
||||||
|
try:
|
||||||
|
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
return datetime.min.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
|
||||||
|
def _newest_connection_by_provider(connections: list[dict[str, Any]]) -> dict[str, dict[str, Any]]:
|
||||||
|
by_provider: dict[str, dict[str, Any]] = {}
|
||||||
|
for item in connections:
|
||||||
|
existing = by_provider.get(item["provider"])
|
||||||
|
if existing is None or _connection_updated_at(item) > _connection_updated_at(existing):
|
||||||
|
by_provider[item["provider"]] = item
|
||||||
|
return by_provider
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_fields(provider: str) -> list[ChannelCredentialFieldResponse]:
|
||||||
|
fields = _CREDENTIAL_FIELDS.get(provider)
|
||||||
|
if fields is None:
|
||||||
|
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||||
|
return [ChannelCredentialFieldResponse(**field) for field in fields]
|
||||||
|
|
||||||
|
|
||||||
|
def _credential_values(provider: str, channels_config: dict[str, Any]) -> dict[str, str]:
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if not isinstance(runtime_config, dict):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
values: dict[str, str] = {}
|
||||||
|
for field in _credential_fields(provider):
|
||||||
|
value = str(runtime_config.get(field.name) or "").strip()
|
||||||
|
if not value:
|
||||||
|
continue
|
||||||
|
values[field.name] = _MASKED_CREDENTIAL_VALUE if field.type == "password" else value
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
def _provider_response(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
channels_config: dict[str, Any],
|
||||||
|
provider: str,
|
||||||
|
meta: dict[str, str],
|
||||||
|
connection: dict[str, Any] | None = None,
|
||||||
|
) -> ChannelProviderResponse:
|
||||||
|
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||||
|
if connection:
|
||||||
|
connection_status = connection["status"]
|
||||||
|
elif status["configured"] and unavailable_reason is None:
|
||||||
|
connection_status = "connected"
|
||||||
|
else:
|
||||||
|
connection_status = "not_connected"
|
||||||
|
credential_values = _credential_values(provider, channels_config)
|
||||||
|
if provider == "telegram" and not credential_values.get("bot_username"):
|
||||||
|
bot_username = str(_provider_config(config, provider).bot_username or "").strip()
|
||||||
|
if bot_username:
|
||||||
|
credential_values["bot_username"] = bot_username
|
||||||
|
return ChannelProviderResponse(
|
||||||
|
provider=provider,
|
||||||
|
display_name=meta["display_name"],
|
||||||
|
enabled=status["enabled"],
|
||||||
|
configured=status["configured"],
|
||||||
|
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
|
||||||
|
unavailable_reason=unavailable_reason,
|
||||||
|
auth_mode=meta["auth_mode"],
|
||||||
|
connection_status=connection_status,
|
||||||
|
credential_fields=_credential_fields(provider),
|
||||||
|
credential_values=credential_values,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _required_runtime_values(
|
||||||
|
provider: str,
|
||||||
|
values: dict[str, str],
|
||||||
|
existing_config: dict[str, Any] | None = None,
|
||||||
|
) -> dict[str, str]:
|
||||||
|
fields = _credential_fields(provider)
|
||||||
|
cleaned: dict[str, str] = {}
|
||||||
|
missing: list[str] = []
|
||||||
|
existing_config = existing_config or {}
|
||||||
|
for field in fields:
|
||||||
|
raw_value = values.get(field.name, "")
|
||||||
|
if field.type == "password" and raw_value == _MASKED_CREDENTIAL_VALUE:
|
||||||
|
existing_value = str(existing_config.get(field.name) or "").strip()
|
||||||
|
if existing_value:
|
||||||
|
cleaned[field.name] = existing_value
|
||||||
|
continue
|
||||||
|
value = raw_value.strip() if isinstance(raw_value, str) else str(raw_value or "").strip()
|
||||||
|
if field.required and not value:
|
||||||
|
missing.append(field.label)
|
||||||
|
cleaned[field.name] = value
|
||||||
|
if missing:
|
||||||
|
raise HTTPException(status_code=400, detail=f"Missing required channel configuration: {', '.join(missing)}")
|
||||||
|
return cleaned
|
||||||
|
|
||||||
|
|
||||||
|
async def _restart_runtime_channel_if_available(provider: str, runtime_config: dict[str, Any]) -> bool | None:
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to import channel service while configuring %s", provider)
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
return await service.configure_channel(provider, runtime_config)
|
||||||
|
|
||||||
|
|
||||||
|
async def _sync_runtime_channel_after_removal(provider: str, channels_config: dict[str, Any]) -> bool | None:
|
||||||
|
try:
|
||||||
|
from app.channels.service import get_channel_service
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Failed to import channel service while disconnecting %s", provider)
|
||||||
|
return None
|
||||||
|
|
||||||
|
service = get_channel_service()
|
||||||
|
if service is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
runtime_config = channels_config.get(provider)
|
||||||
|
if isinstance(runtime_config, dict) and runtime_config.get("enabled", False):
|
||||||
|
return await service.configure_channel(provider, runtime_config)
|
||||||
|
return await service.remove_channel(provider)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||||
|
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||||
|
config = _get_channel_connections_config(request)
|
||||||
|
channels_config = _get_channels_config(request)
|
||||||
|
repo = None
|
||||||
|
if config.enabled:
|
||||||
|
try:
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if exc.status_code != 503:
|
||||||
|
raise
|
||||||
|
owner_user_id = _get_user_id(request)
|
||||||
|
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||||
|
by_provider = _newest_connection_by_provider(connections)
|
||||||
|
|
||||||
|
providers: list[ChannelProviderResponse] = []
|
||||||
|
for provider, meta in _PROVIDER_META.items():
|
||||||
|
if not config.provider_status(provider)["enabled"]:
|
||||||
|
continue
|
||||||
|
connection = by_provider.get(provider)
|
||||||
|
providers.append(_provider_response(config, channels_config, provider, meta, connection))
|
||||||
|
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/connections", response_model=ChannelConnectionsResponse)
|
||||||
|
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
|
||||||
|
config = _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
return ChannelConnectionsResponse(connections=[])
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
rows = await repo.list_connections(_get_user_id(request))
|
||||||
|
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/connections/{connection_id}", status_code=204)
|
||||||
|
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
|
||||||
|
config = _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
disconnected = await repo.disconnect_connection(
|
||||||
|
connection_id=connection_id,
|
||||||
|
owner_user_id=_get_user_id(request),
|
||||||
|
)
|
||||||
|
if not disconnected:
|
||||||
|
raise HTTPException(status_code=404, detail="Channel connection not found")
|
||||||
|
return Response(status_code=204)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||||
|
async def disconnect_channel_provider_runtime(provider: str, request: Request) -> ChannelProviderResponse:
|
||||||
|
config = _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if not provider_config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
|
||||||
|
owner_user_id = _get_user_id(request)
|
||||||
|
try:
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
except HTTPException as exc:
|
||||||
|
if exc.status_code != 503:
|
||||||
|
raise
|
||||||
|
repo = None
|
||||||
|
|
||||||
|
if repo is not None:
|
||||||
|
for connection in await repo.list_connections(owner_user_id):
|
||||||
|
if connection["provider"] == provider and connection["status"] != "revoked":
|
||||||
|
await repo.disconnect_connection(
|
||||||
|
connection_id=connection["id"],
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
_get_runtime_config_store(request).remove_provider_config(provider)
|
||||||
|
channels_config = _load_channels_config(request, config)
|
||||||
|
request.app.state.channels_config = channels_config
|
||||||
|
|
||||||
|
stopped = await _sync_runtime_channel_after_removal(provider, channels_config)
|
||||||
|
if stopped is False:
|
||||||
|
display_name = _PROVIDER_META[provider]["display_name"]
|
||||||
|
raise HTTPException(status_code=400, detail=f"Failed to stop {display_name} channel. Try again.")
|
||||||
|
|
||||||
|
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||||
|
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||||
|
config = _get_channel_connections_config(request)
|
||||||
|
channels_config = _get_channels_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||||
|
if not status["enabled"]:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
if unavailable_reason:
|
||||||
|
raise HTTPException(status_code=400, detail=unavailable_reason)
|
||||||
|
if not status["configured"]:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||||
|
|
||||||
|
repo = _get_repository(request, config)
|
||||||
|
code = await _create_state(
|
||||||
|
repo,
|
||||||
|
owner_user_id=_get_user_id(request),
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
return ChannelConnectResponse(
|
||||||
|
provider=provider,
|
||||||
|
mode=_PROVIDER_META[provider]["auth_mode"],
|
||||||
|
url=_connect_url(config, provider, code),
|
||||||
|
code=code,
|
||||||
|
instruction=_connect_instruction(provider, code),
|
||||||
|
expires_in=_STATE_TTL_SECONDS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{provider}/runtime-config", response_model=ChannelProviderResponse)
|
||||||
|
async def configure_channel_provider_runtime(
|
||||||
|
provider: str,
|
||||||
|
body: ChannelRuntimeConfigRequest,
|
||||||
|
request: Request,
|
||||||
|
) -> ChannelProviderResponse:
|
||||||
|
config = _get_channel_connections_config(request)
|
||||||
|
if not config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||||
|
|
||||||
|
provider_config = _provider_config(config, provider)
|
||||||
|
if not provider_config.enabled:
|
||||||
|
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||||
|
|
||||||
|
channels_config = _get_channels_config(request)
|
||||||
|
existing = channels_config.get(provider)
|
||||||
|
runtime_config = dict(existing) if isinstance(existing, dict) else {}
|
||||||
|
values = _required_runtime_values(provider, body.values, runtime_config)
|
||||||
|
runtime_config["enabled"] = True
|
||||||
|
|
||||||
|
for key in _RUNTIME_REQUIREMENTS[provider]:
|
||||||
|
runtime_config[key] = values[key]
|
||||||
|
|
||||||
|
if provider == "telegram":
|
||||||
|
runtime_config["bot_username"] = values["bot_username"]
|
||||||
|
provider_config.bot_username = values["bot_username"]
|
||||||
|
request.app.state.channel_connections_config = config
|
||||||
|
|
||||||
|
channels_config[provider] = runtime_config
|
||||||
|
request.app.state.channels_config = channels_config
|
||||||
|
|
||||||
|
started = await _restart_runtime_channel_if_available(provider, runtime_config)
|
||||||
|
if started is False:
|
||||||
|
display_name = _PROVIDER_META[provider]["display_name"]
|
||||||
|
raise HTTPException(status_code=400, detail=f"Failed to start {display_name} channel. Check the values and try again.")
|
||||||
|
|
||||||
|
_get_runtime_config_store(request).set_provider_config(provider, runtime_config)
|
||||||
|
|
||||||
|
return _provider_response(config, channels_config, provider, _PROVIDER_META[provider])
|
||||||
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
|
|||||||
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
||||||
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
||||||
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
||||||
|
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
|
||||||
|
|
||||||
|
|
||||||
class MemoryStatusResponse(BaseModel):
|
class MemoryStatusResponse(BaseModel):
|
||||||
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
"max_facts": 100,
|
"max_facts": 100,
|
||||||
"fact_confidence_threshold": 0.7,
|
"fact_confidence_threshold": 0.7,
|
||||||
"injection_enabled": true,
|
"injection_enabled": true,
|
||||||
"max_injection_tokens": 2000
|
"max_injection_tokens": 2000,
|
||||||
|
"token_counting": "tiktoken"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
|
token_counting=config.token_counting,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
|||||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||||
injection_enabled=config.injection_enabled,
|
injection_enabled=config.injection_enabled,
|
||||||
max_injection_tokens=config.max_injection_tokens,
|
max_injection_tokens=config.max_injection_tokens,
|
||||||
|
token_counting=config.token_counting,
|
||||||
),
|
),
|
||||||
data=MemoryResponse(**memory_data),
|
data=MemoryResponse(**memory_data),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer
|
from app.gateway.deps import get_checkpointer
|
||||||
|
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
thread_store = get_thread_store(request)
|
thread_store = get_thread_store(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = now_iso()
|
now = now_iso()
|
||||||
|
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||||
|
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
# Idempotency: return existing record when already present
|
||||||
existing_record = await thread_store.get(thread_id)
|
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||||
|
if existing_record is None and thread_owner_user_id:
|
||||||
|
unscoped_record = await thread_store.get(thread_id, user_id=None)
|
||||||
|
if unscoped_record is not None:
|
||||||
|
if unscoped_record.get("user_id") != thread_owner_user_id:
|
||||||
|
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
|
||||||
|
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||||
if existing_record is not None:
|
if existing_record is not None:
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
await thread_store.create(
|
await thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
assistant_id=getattr(body, "assistant_id", None),
|
||||||
|
**thread_owner_kwargs,
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
|
|||||||
from langchain_core.messages.utils import convert_to_messages
|
from langchain_core.messages.utils import convert_to_messages
|
||||||
|
|
||||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.runtime import (
|
from deerflow.runtime import (
|
||||||
@@ -35,6 +36,7 @@ from deerflow.runtime import (
|
|||||||
run_agent,
|
run_agent,
|
||||||
)
|
)
|
||||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
from deerflow.runtime.runs.naming import resolve_root_run_name
|
||||||
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -315,72 +317,100 @@ async def start_run(
|
|||||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||||
record = await run_mgr.create_or_reject(
|
# Stateless run endpoints carry thread_id in the request *body*, so the
|
||||||
thread_id,
|
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
||||||
body.assistant_id,
|
# from the path param -- cannot protect them. Enforce thread ownership here,
|
||||||
on_disconnect=disconnect,
|
# before any run is created, so one user cannot start runs on (or read /wait
|
||||||
metadata=body.metadata or {},
|
# checkpoint state from) another user's thread. Missing rows (auto-created
|
||||||
kwargs={"input": body.input, "config": body.config},
|
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
||||||
multitask_strategy=body.multitask_strategy,
|
# via check_access; only a thread already owned by another user is rejected
|
||||||
model_name=model_name,
|
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
||||||
)
|
# channel runs act on behalf of IM users they do not own (see
|
||||||
except ConflictError as exc:
|
# inject_authenticated_user_context), so the internal system role is exempt.
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
user = getattr(request.state, "user", None)
|
||||||
except UnsupportedStrategyError as exc:
|
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)):
|
||||||
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
# Upsert thread metadata so the thread appears in /threads/search,
|
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
||||||
# even for threads that were never explicitly created via POST /threads
|
|
||||||
# (e.g. stateless runs).
|
|
||||||
try:
|
try:
|
||||||
existing = await run_ctx.thread_store.get(thread_id)
|
try:
|
||||||
if existing is None:
|
record = await run_mgr.create_or_reject(
|
||||||
await run_ctx.thread_store.create(
|
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=body.assistant_id,
|
body.assistant_id,
|
||||||
metadata=body.metadata,
|
on_disconnect=disconnect,
|
||||||
|
metadata=body.metadata or {},
|
||||||
|
kwargs={"input": body.input, "config": body.config},
|
||||||
|
multitask_strategy=body.multitask_strategy,
|
||||||
|
model_name=model_name,
|
||||||
|
user_id=owner_user_id,
|
||||||
)
|
)
|
||||||
else:
|
except ConflictError as exc:
|
||||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
except Exception:
|
except UnsupportedStrategyError as exc:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||||
|
|
||||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
# Upsert thread metadata so the thread appears in /threads/search,
|
||||||
graph_input = normalize_input(body.input)
|
# even for threads that were never explicitly created via POST /threads
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
# (e.g. stateless runs).
|
||||||
|
try:
|
||||||
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
|
if existing is None and owner_user_id:
|
||||||
|
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
||||||
|
if unscoped_existing is not None:
|
||||||
|
if unscoped_existing.get("user_id") != owner_user_id:
|
||||||
|
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
||||||
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
|
if existing is None:
|
||||||
|
await run_ctx.thread_store.create(
|
||||||
|
thread_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
metadata=body.metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
graph_input = normalize_input(body.input)
|
||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
|
||||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
|
||||||
inject_authenticated_user_context(config, request)
|
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||||
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||||
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||||
|
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||||
|
inject_authenticated_user_context(config, request)
|
||||||
|
|
||||||
task = asyncio.create_task(
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||||
run_agent(
|
|
||||||
bridge,
|
task = asyncio.create_task(
|
||||||
run_mgr,
|
run_agent(
|
||||||
record,
|
bridge,
|
||||||
ctx=run_ctx,
|
run_mgr,
|
||||||
agent_factory=agent_factory,
|
record,
|
||||||
graph_input=graph_input,
|
ctx=run_ctx,
|
||||||
config=config,
|
agent_factory=agent_factory,
|
||||||
stream_modes=stream_modes,
|
graph_input=graph_input,
|
||||||
stream_subgraphs=body.stream_subgraphs,
|
config=config,
|
||||||
interrupt_before=body.interrupt_before,
|
stream_modes=stream_modes,
|
||||||
interrupt_after=body.interrupt_after,
|
stream_subgraphs=body.stream_subgraphs,
|
||||||
|
interrupt_before=body.interrupt_before,
|
||||||
|
interrupt_after=body.interrupt_after,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
record.task = task
|
||||||
record.task = task
|
|
||||||
|
|
||||||
# Title sync is handled by worker.py's finally block which reads the
|
# Title sync is handled by worker.py's finally block which reads the
|
||||||
# title from the checkpoint and calls thread_store.update_display_name
|
# title from the checkpoint and calls thread_store.update_display_name
|
||||||
# after the run completes.
|
# after the run completes.
|
||||||
|
|
||||||
return record
|
return record
|
||||||
|
finally:
|
||||||
|
if owner_context_token is not None:
|
||||||
|
reset_current_user(owner_context_token)
|
||||||
|
|
||||||
|
|
||||||
async def sse_consumer(
|
async def sse_consumer(
|
||||||
|
|||||||
@@ -0,0 +1,121 @@
|
|||||||
|
# IM Channel Connections
|
||||||
|
|
||||||
|
DeerFlow supports user-owned IM channel bindings for Telegram, Slack, Discord, Feishu/Lark, DingTalk, WeChat, and WeCom. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow.
|
||||||
|
|
||||||
|
No public IP, OAuth callback URL, or provider webhook is required in this implementation.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
Configure the actual IM bots under the existing `channels` block:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
channels:
|
||||||
|
telegram:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $TELEGRAM_BOT_TOKEN
|
||||||
|
|
||||||
|
slack:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $SLACK_BOT_TOKEN
|
||||||
|
app_token: $SLACK_APP_TOKEN
|
||||||
|
|
||||||
|
discord:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $DISCORD_BOT_TOKEN
|
||||||
|
|
||||||
|
feishu:
|
||||||
|
enabled: true
|
||||||
|
app_id: $FEISHU_APP_ID
|
||||||
|
app_secret: $FEISHU_APP_SECRET
|
||||||
|
|
||||||
|
dingtalk:
|
||||||
|
enabled: true
|
||||||
|
client_id: $DINGTALK_CLIENT_ID
|
||||||
|
client_secret: $DINGTALK_CLIENT_SECRET
|
||||||
|
|
||||||
|
wechat:
|
||||||
|
enabled: true
|
||||||
|
bot_token: $WECHAT_BOT_TOKEN
|
||||||
|
|
||||||
|
wecom:
|
||||||
|
enabled: true
|
||||||
|
bot_id: $WECOM_BOT_ID
|
||||||
|
bot_secret: $WECOM_BOT_SECRET
|
||||||
|
```
|
||||||
|
|
||||||
|
Then enable user bindings in `channel_connections`:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
channel_connections:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
telegram:
|
||||||
|
enabled: true
|
||||||
|
bot_username: $TELEGRAM_BOT_USERNAME
|
||||||
|
|
||||||
|
slack:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
discord:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
feishu:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
dingtalk:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
wechat:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
wecom:
|
||||||
|
enabled: true
|
||||||
|
```
|
||||||
|
|
||||||
|
`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link.
|
||||||
|
|
||||||
|
## Connect Flow
|
||||||
|
|
||||||
|
Telegram:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The Connect button opens `https://t.me/<bot_username>?start=<code>`.
|
||||||
|
- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user.
|
||||||
|
|
||||||
|
Slack:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The UI shows `Send /connect <code> to the DeerFlow Slack bot.`
|
||||||
|
- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user.
|
||||||
|
|
||||||
|
Discord:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The UI shows `Send /connect <code> to the DeerFlow Discord bot.`
|
||||||
|
- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user.
|
||||||
|
|
||||||
|
Feishu/Lark, DingTalk, WeChat, and WeCom:
|
||||||
|
|
||||||
|
- The frontend creates a short one-time code.
|
||||||
|
- The UI shows `Send /connect <code> to the DeerFlow <Provider> bot.`
|
||||||
|
- The already-running long-connection or polling worker receives the message and binds the platform user/workspace identity to the current DeerFlow user.
|
||||||
|
|
||||||
|
Codes expire after 10 minutes and are single-use.
|
||||||
|
|
||||||
|
## Runtime Model
|
||||||
|
|
||||||
|
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
|
||||||
|
|
||||||
|
- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata.
|
||||||
|
- `channel_oauth_states`: one-time connect codes and Telegram deep-link state.
|
||||||
|
- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping.
|
||||||
|
- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow.
|
||||||
|
|
||||||
|
Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`.
|
||||||
|
|
||||||
|
## Security Notes
|
||||||
|
|
||||||
|
- Browser APIs remain authenticated and CSRF-protected.
|
||||||
|
- Connect codes are random, short-lived, and single-use.
|
||||||
|
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
|
||||||
|
- This implementation does not add public provider callback or webhook routes.
|
||||||
@@ -31,7 +31,8 @@ Current injection format:
|
|||||||
|
|
||||||
Token counting:
|
Token counting:
|
||||||
- Uses `tiktoken` (`cl100k_base`) when available
|
- Uses `tiktoken` (`cl100k_base`) when available
|
||||||
- Falls back to `len(text) // 4` if tokenizer import fails
|
- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
|
||||||
|
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
|
||||||
|
|
||||||
## Known Gap
|
## Known Gap
|
||||||
|
|
||||||
|
|||||||
@@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
|
|||||||
return ""
|
return ""
|
||||||
|
|
||||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
memory_content = format_memory_for_injection(
|
||||||
|
memory_data,
|
||||||
|
max_tokens=config.max_injection_tokens,
|
||||||
|
use_tiktoken=(config.token_counting == "tiktoken"),
|
||||||
|
)
|
||||||
|
|
||||||
if not memory_content.strip():
|
if not memory_content.strip():
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
import re
|
import re
|
||||||
from typing import Any
|
import threading
|
||||||
|
import time
|
||||||
|
from typing import Any, cast
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
|
|||||||
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
||||||
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
||||||
# (potentially slow) first ``get_encoding`` call.
|
# (potentially slow) first ``get_encoding`` call.
|
||||||
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
|
#
|
||||||
|
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
|
||||||
|
# a network-restricted environment does not re-attempt the blocking BPE
|
||||||
|
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
|
||||||
|
# failure is allowed to expire so a transient network outage can self-heal back
|
||||||
|
# to accurate tiktoken counting without a process restart. A load already in
|
||||||
|
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
|
||||||
|
# fall back immediately instead of spawning more blocking
|
||||||
|
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
|
||||||
|
# config to skip tiktoken entirely.
|
||||||
|
_TIKTOKEN_ENCODING_MISSING = object()
|
||||||
|
_TIKTOKEN_ENCODING_LOADING = object()
|
||||||
|
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
|
||||||
|
# tuning constant rather than a user-facing config: it only affects how quickly
|
||||||
|
# the default ``tiktoken`` mode self-heals after a transient network outage.
|
||||||
|
# Deployments that want to avoid tiktoken's network dependency entirely should
|
||||||
|
# set ``memory.token_counting: char`` instead of tuning this value.
|
||||||
|
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
|
||||||
|
_tiktoken_encoding_cache: dict[str, Any] = {}
|
||||||
|
_tiktoken_encoding_cache_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
||||||
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
|
|||||||
download can block for tens of minutes before the OS TCP timeout kicks in.
|
download can block for tens of minutes before the OS TCP timeout kicks in.
|
||||||
The caller must therefore be prepared for this to block and should run it
|
The caller must therefore be prepared for this to block and should run it
|
||||||
off the event loop (e.g. via ``asyncio.to_thread``).
|
off the event loop (e.g. via ``asyncio.to_thread``).
|
||||||
|
|
||||||
|
A failed load is remembered (with a timestamp) so subsequent calls fall
|
||||||
|
back immediately to character-based estimation instead of re-triggering the
|
||||||
|
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
|
||||||
|
so a transient outage can self-heal without a restart. A load already in
|
||||||
|
progress is also remembered so that a timed-out caller does not leave a
|
||||||
|
window where later requests start more blocking ``get_encoding`` calls.
|
||||||
"""
|
"""
|
||||||
if not TIKTOKEN_AVAILABLE:
|
if not TIKTOKEN_AVAILABLE:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
cached = _tiktoken_encoding_cache.get(encoding_name)
|
with _tiktoken_encoding_cache_lock:
|
||||||
if cached is not None:
|
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
|
||||||
return cached
|
if cached is _TIKTOKEN_ENCODING_LOADING:
|
||||||
|
return None
|
||||||
|
if isinstance(cached, tuple):
|
||||||
|
# Cached failure: (None, failed_at). Retry only after cooldown.
|
||||||
|
_, failed_at = cached
|
||||||
|
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
|
||||||
|
return None
|
||||||
|
cached = _TIKTOKEN_ENCODING_MISSING
|
||||||
|
if cached is not _TIKTOKEN_ENCODING_MISSING:
|
||||||
|
return cast("tiktoken.Encoding", cached)
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
|
||||||
|
|
||||||
try:
|
try:
|
||||||
encoding = tiktoken.get_encoding(encoding_name)
|
encoding = tiktoken.get_encoding(encoding_name)
|
||||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
|
||||||
return encoding
|
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
||||||
|
with _tiktoken_encoding_cache_lock:
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
with _tiktoken_encoding_cache_lock:
|
||||||
|
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||||
|
return encoding
|
||||||
|
|
||||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
|
||||||
|
def _char_based_token_estimate(text: str) -> int:
|
||||||
|
"""Network-free token estimate that accounts for CJK density.
|
||||||
|
|
||||||
|
The plain ``len(text) // 4`` heuristic is reasonable for English/code
|
||||||
|
(~4 chars per token) but significantly under-estimates token counts for
|
||||||
|
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
|
||||||
|
characters per token. Counting CJK characters separately (~2 chars per
|
||||||
|
token) avoids over-filling the injection budget for CJK-heavy memory
|
||||||
|
content.
|
||||||
|
"""
|
||||||
|
cjk = sum(
|
||||||
|
1
|
||||||
|
for ch in text
|
||||||
|
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
|
||||||
|
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
|
||||||
|
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
|
||||||
|
)
|
||||||
|
return (len(text) - cjk) // 4 + cjk // 2
|
||||||
|
|
||||||
|
|
||||||
|
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
|
||||||
"""Count tokens in text using tiktoken.
|
"""Count tokens in text using tiktoken.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The text to count tokens for.
|
text: The text to count tokens for.
|
||||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||||
|
use_tiktoken: When ``False``, skip tiktoken entirely and use the
|
||||||
|
network-free character-based estimate. This guarantees no BPE
|
||||||
|
download is attempted (see ``memory.token_counting`` config).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
The number of tokens in the text.
|
The number of tokens in the text.
|
||||||
"""
|
"""
|
||||||
|
if not use_tiktoken:
|
||||||
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
encoding = _get_tiktoken_encoding(encoding_name)
|
encoding = _get_tiktoken_encoding(encoding_name)
|
||||||
if encoding is None:
|
if encoding is None:
|
||||||
# Fallback to character-based estimation if tiktoken is not available
|
# Fallback to CJK-aware character estimation if tiktoken is not
|
||||||
# or the encoding failed to load.
|
# available or the encoding failed to load.
|
||||||
return len(text) // 4
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return len(encoding.encode(text))
|
return len(encoding.encode(text))
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fallback to character-based estimation on error
|
# Fallback to CJK-aware character estimation on error.
|
||||||
return len(text) // 4
|
return _char_based_token_estimate(text)
|
||||||
|
|
||||||
|
|
||||||
def warm_tiktoken_cache() -> bool:
|
def warm_tiktoken_cache() -> bool:
|
||||||
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
|||||||
return max(0.0, min(1.0, confidence))
|
return max(0.0, min(1.0, confidence))
|
||||||
|
|
||||||
|
|
||||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
|
||||||
"""Format memory data for injection into system prompt.
|
"""Format memory data for injection into system prompt.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
memory_data: The memory data dictionary.
|
memory_data: The memory data dictionary.
|
||||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||||
|
use_tiktoken: When ``False``, all token counting uses the network-free
|
||||||
|
character-based estimate instead of tiktoken (see
|
||||||
|
``memory.token_counting`` config). Defaults to ``True``.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Formatted memory string for system prompt injection.
|
Formatted memory string for system prompt injection.
|
||||||
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
# Compute token count for existing sections once, then account
|
# Compute token count for existing sections once, then account
|
||||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||||
base_text = "\n\n".join(sections)
|
base_text = "\n\n".join(sections)
|
||||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
|
||||||
# Account for the separator between existing sections and the facts section.
|
# Account for the separator between existing sections and the facts section.
|
||||||
facts_header = "Facts:\n"
|
facts_header = "Facts:\n"
|
||||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
|
||||||
running_tokens = base_tokens + separator_tokens
|
running_tokens = base_tokens + separator_tokens
|
||||||
|
|
||||||
fact_lines: list[str] = []
|
fact_lines: list[str] = []
|
||||||
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
# Each additional line is preceded by a newline (except the first).
|
# Each additional line is preceded by a newline (except the first).
|
||||||
line_text = ("\n" + line) if fact_lines else line
|
line_text = ("\n" + line) if fact_lines else line
|
||||||
line_tokens = _count_tokens(line_text)
|
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
|
||||||
|
|
||||||
if running_tokens + line_tokens <= max_tokens:
|
if running_tokens + line_tokens <= max_tokens:
|
||||||
fact_lines.append(line)
|
fact_lines.append(line)
|
||||||
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
|||||||
|
|
||||||
result = "\n\n".join(sections)
|
result = "\n\n".join(sections)
|
||||||
|
|
||||||
# Use accurate token counting with tiktoken
|
# Use accurate token counting with tiktoken (or the char-based estimate
|
||||||
token_count = _count_tokens(result)
|
# when use_tiktoken is False).
|
||||||
|
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
|
||||||
if token_count > max_tokens:
|
if token_count > max_tokens:
|
||||||
# Truncate to fit within token limit
|
# Truncate to fit within token limit
|
||||||
# Estimate characters to remove based on token ratio
|
# Estimate characters to remove based on token ratio
|
||||||
|
|||||||
@@ -1141,6 +1141,7 @@ class DeerFlowClient:
|
|||||||
"fact_confidence_threshold": config.fact_confidence_threshold,
|
"fact_confidence_threshold": config.fact_confidence_threshold,
|
||||||
"injection_enabled": config.injection_enabled,
|
"injection_enabled": config.injection_enabled,
|
||||||
"max_injection_tokens": config.max_injection_tokens,
|
"max_injection_tokens": config.max_injection_tokens,
|
||||||
|
"token_counting": config.token_counting,
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_memory_status(self) -> dict:
|
def get_memory_status(self) -> dict:
|
||||||
|
|||||||
@@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
|||||||
paths = get_paths()
|
paths = get_paths()
|
||||||
effective_user = user_id or get_effective_user_id()
|
effective_user = user_id or get_effective_user_id()
|
||||||
user_path = paths.user_agent_dir(effective_user, name)
|
user_path = paths.user_agent_dir(effective_user, name)
|
||||||
if user_path.exists():
|
# Require config.yaml to confirm this is a genuine agent directory,
|
||||||
|
# not a leftover from memory/storage writes (see #3390).
|
||||||
|
if user_path.exists() and (user_path / "config.yaml").exists():
|
||||||
return user_path
|
return user_path
|
||||||
|
|
||||||
legacy_path = paths.agent_dir(name)
|
legacy_path = paths.agent_dir(name)
|
||||||
if legacy_path.exists():
|
if legacy_path.exists() and (legacy_path / "config.yaml").exists():
|
||||||
return legacy_path
|
return legacy_path
|
||||||
|
|
||||||
return user_path
|
return user_path
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
|||||||
|
|
||||||
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
||||||
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||||
from deerflow.config.database_config import DatabaseConfig
|
from deerflow.config.database_config import DatabaseConfig
|
||||||
from deerflow.config.extensions_config import ExtensionsConfig
|
from deerflow.config.extensions_config import ExtensionsConfig
|
||||||
@@ -116,6 +117,7 @@ class AppConfig(BaseModel):
|
|||||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||||
|
channel_connections: ChannelConnectionsConfig = Field(default_factory=ChannelConnectionsConfig, description="User-facing IM channel connection configuration")
|
||||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||||
model_config = ConfigDict(extra="allow")
|
model_config = ConfigDict(extra="allow")
|
||||||
|
|||||||
@@ -0,0 +1,61 @@
|
|||||||
|
"""Configuration for user-owned IM channel connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class SlackChannelConnectionConfig(BaseModel):
|
||||||
|
enabled: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def configured(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class TelegramChannelConnectionConfig(BaseModel):
|
||||||
|
enabled: bool = False
|
||||||
|
bot_username: str = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def configured(self) -> bool:
|
||||||
|
return bool(self.bot_username)
|
||||||
|
|
||||||
|
|
||||||
|
class DiscordChannelConnectionConfig(BaseModel):
|
||||||
|
enabled: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def configured(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class BindingCodeChannelConnectionConfig(BaseModel):
|
||||||
|
enabled: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def configured(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionsConfig(BaseModel):
|
||||||
|
"""Top-level config for browser-connectable IM channels."""
|
||||||
|
|
||||||
|
enabled: bool = False
|
||||||
|
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
|
||||||
|
telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig)
|
||||||
|
discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig)
|
||||||
|
feishu: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||||
|
dingtalk: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||||
|
wechat: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||||
|
wecom: BindingCodeChannelConnectionConfig = Field(default_factory=BindingCodeChannelConnectionConfig)
|
||||||
|
|
||||||
|
def provider_status(self, provider: str) -> dict[str, bool]:
|
||||||
|
config = getattr(self, provider, None)
|
||||||
|
if config is None:
|
||||||
|
return {"enabled": False, "configured": False}
|
||||||
|
enabled = bool(config.enabled)
|
||||||
|
return {
|
||||||
|
"enabled": enabled,
|
||||||
|
"configured": enabled and bool(config.configured),
|
||||||
|
}
|
||||||
@@ -1,5 +1,7 @@
|
|||||||
"""Configuration for memory mechanism."""
|
"""Configuration for memory mechanism."""
|
||||||
|
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
@@ -60,6 +62,17 @@ class MemoryConfig(BaseModel):
|
|||||||
le=8000,
|
le=8000,
|
||||||
description="Maximum tokens to use for memory injection",
|
description="Maximum tokens to use for memory injection",
|
||||||
)
|
)
|
||||||
|
token_counting: Literal["tiktoken", "char"] = Field(
|
||||||
|
default="tiktoken",
|
||||||
|
description=(
|
||||||
|
"Token counting strategy for memory-injection budgeting. "
|
||||||
|
"'tiktoken' is accurate but the encoding's BPE data may be "
|
||||||
|
"downloaded from a public network endpoint on first use, which "
|
||||||
|
"can block for a long time in network-restricted environments "
|
||||||
|
"(see issue #3402/#3429). 'char' uses a network-free "
|
||||||
|
"CJK-aware character-based estimate and never touches tiktoken."
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
# Global configuration instance
|
# Global configuration instance
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def make_safe_user_id(raw: str) -> str:
|
|||||||
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
|
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
|
||||||
if sanitized == raw:
|
if sanitized == raw:
|
||||||
return raw
|
return raw
|
||||||
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||||
return f"{sanitized}-{digest}"
|
return f"{sanitized}-{digest}"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,20 @@ from pydantic import BaseModel, ConfigDict, Field
|
|||||||
class VolumeMountConfig(BaseModel):
|
class VolumeMountConfig(BaseModel):
|
||||||
"""Configuration for a volume mount."""
|
"""Configuration for a volume mount."""
|
||||||
|
|
||||||
host_path: str = Field(..., description="Path on the host machine")
|
host_path: str = Field(
|
||||||
|
...,
|
||||||
|
description=(
|
||||||
|
"Source path for the mount. Resolution depends on the active provider: "
|
||||||
|
"``LocalSandboxProvider`` checks this path from the gateway process — in "
|
||||||
|
"``make dev`` that is the host machine, but in Docker deployments "
|
||||||
|
"(``make up`` / docker-compose) it is the path *inside* the "
|
||||||
|
"``deer-flow-gateway`` container, so the host directory must also be "
|
||||||
|
"bind-mounted into the gateway service for the mount to take effect. "
|
||||||
|
"``AioSandboxProvider`` (DooD) passes this value straight to ``docker -v`` "
|
||||||
|
"for the sandbox container, where it is resolved by the host Docker daemon "
|
||||||
|
"from the host machine's perspective."
|
||||||
|
),
|
||||||
|
)
|
||||||
container_path: str = Field(..., description="Path inside the container")
|
container_path: str = Field(..., description="Path inside the container")
|
||||||
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
read_only: bool = Field(default=False, description="Whether the mount is read-only")
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,21 @@
|
|||||||
|
"""User-owned IM channel connection persistence."""
|
||||||
|
|
||||||
|
from deerflow.persistence.channel_connections.model import (
|
||||||
|
ChannelConnectionRow,
|
||||||
|
ChannelConversationRow,
|
||||||
|
ChannelCredentialRow,
|
||||||
|
ChannelOAuthStateRow,
|
||||||
|
)
|
||||||
|
from deerflow.persistence.channel_connections.sql import (
|
||||||
|
ChannelConnectionRepository,
|
||||||
|
ChannelCredentialCipher,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ChannelConnectionRepository",
|
||||||
|
"ChannelConnectionRow",
|
||||||
|
"ChannelConversationRow",
|
||||||
|
"ChannelCredentialCipher",
|
||||||
|
"ChannelCredentialRow",
|
||||||
|
"ChannelOAuthStateRow",
|
||||||
|
]
|
||||||
@@ -0,0 +1,111 @@
|
|||||||
|
"""ORM models for user-owned IM channel connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
|
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from deerflow.persistence.base import Base
|
||||||
|
|
||||||
|
|
||||||
|
def _utc_now() -> datetime:
|
||||||
|
return datetime.now(UTC)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionRow(Base):
|
||||||
|
__tablename__ = "channel_connections"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
|
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||||
|
status: Mapped[str] = mapped_column(String(32), nullable=False, default="connected")
|
||||||
|
|
||||||
|
external_account_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||||
|
external_account_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||||
|
workspace_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||||
|
workspace_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||||
|
bot_user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||||
|
|
||||||
|
scopes_json: Mapped[list] = mapped_column(JSON, default=list)
|
||||||
|
capabilities_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
|
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
|
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||||
|
last_seen_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
last_error_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"owner_user_id",
|
||||||
|
"provider",
|
||||||
|
"external_account_id",
|
||||||
|
"workspace_id",
|
||||||
|
name="uq_channel_connection_owner_provider_identity",
|
||||||
|
),
|
||||||
|
Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelCredentialRow(Base):
|
||||||
|
__tablename__ = "channel_credentials"
|
||||||
|
|
||||||
|
connection_id: Mapped[str] = mapped_column(
|
||||||
|
String(64),
|
||||||
|
ForeignKey("channel_connections.id", ondelete="CASCADE"),
|
||||||
|
primary_key=True,
|
||||||
|
)
|
||||||
|
encrypted_access_token: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
encrypted_refresh_token: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
token_type: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||||
|
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
encrypted_extra_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelOAuthStateRow(Base):
|
||||||
|
__tablename__ = "channel_oauth_states"
|
||||||
|
|
||||||
|
state_hash: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||||
|
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
|
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||||
|
code_verifier_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
nonce_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||||
|
redirect_after: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
requested_scopes_json: Mapped[list] = mapped_column(JSON, default=list)
|
||||||
|
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConversationRow(Base):
|
||||||
|
__tablename__ = "channel_conversations"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
connection_id: Mapped[str] = mapped_column(
|
||||||
|
String(64),
|
||||||
|
ForeignKey("channel_connections.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
index=True,
|
||||||
|
)
|
||||||
|
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
|
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||||
|
external_conversation_id: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||||
|
external_topic_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||||
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
UniqueConstraint(
|
||||||
|
"connection_id",
|
||||||
|
"external_conversation_id",
|
||||||
|
"external_topic_id",
|
||||||
|
name="uq_channel_conversation_connection_external",
|
||||||
|
),
|
||||||
|
)
|
||||||
@@ -0,0 +1,349 @@
|
|||||||
|
"""SQL repository for user-owned IM channel connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
|
from deerflow.persistence.channel_connections.model import (
|
||||||
|
ChannelConnectionRow,
|
||||||
|
ChannelConversationRow,
|
||||||
|
ChannelCredentialRow,
|
||||||
|
ChannelOAuthStateRow,
|
||||||
|
)
|
||||||
|
from deerflow.utils.time import coerce_iso
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelCredentialCipher:
|
||||||
|
"""Encrypts provider credentials before they are persisted."""
|
||||||
|
|
||||||
|
def __init__(self, fernet: Fernet) -> None:
|
||||||
|
self._fernet = fernet
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_key(cls, key: str) -> ChannelCredentialCipher:
|
||||||
|
digest = hashlib.sha256(key.encode("utf-8")).digest()
|
||||||
|
return cls(Fernet(base64.urlsafe_b64encode(digest)))
|
||||||
|
|
||||||
|
def encrypt_text(self, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")
|
||||||
|
|
||||||
|
def decrypt_text(self, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
token = value.removeprefix("fernet:v1:")
|
||||||
|
return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
class ChannelConnectionRepository:
|
||||||
|
"""Persistence facade for channel connections, credentials, and conversations."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session_factory: async_sessionmaker[AsyncSession],
|
||||||
|
*,
|
||||||
|
cipher: ChannelCredentialCipher | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.session_factory = session_factory
|
||||||
|
self._cipher = cipher
|
||||||
|
|
||||||
|
async def close(self) -> None:
|
||||||
|
from deerflow.persistence.engine import close_engine
|
||||||
|
|
||||||
|
await close_engine()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _new_id() -> str:
|
||||||
|
return uuid.uuid4().hex
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _normalize_optional_identity(value: str | None) -> str:
|
||||||
|
return value or ""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _coerce_datetime(value: datetime | None) -> datetime | None:
|
||||||
|
if value is None or value.tzinfo is not None:
|
||||||
|
return value
|
||||||
|
return value.replace(tzinfo=UTC)
|
||||||
|
|
||||||
|
def _encrypt_optional_secret(self, value: str | None) -> str | None:
|
||||||
|
if value is None:
|
||||||
|
return None
|
||||||
|
if self._cipher is None:
|
||||||
|
raise RuntimeError("channel connection encryption key is required")
|
||||||
|
return self._cipher.encrypt_text(value)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
|
||||||
|
data = row.to_dict()
|
||||||
|
data["external_account_id"] = data["external_account_id"] or None
|
||||||
|
data["workspace_id"] = data["workspace_id"] or None
|
||||||
|
data["scopes"] = data.pop("scopes_json") or []
|
||||||
|
data["capabilities"] = data.pop("capabilities_json") or {}
|
||||||
|
data["metadata"] = data.pop("metadata_json") or {}
|
||||||
|
for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
|
||||||
|
value = data.get(key)
|
||||||
|
if isinstance(value, datetime):
|
||||||
|
data[key] = coerce_iso(value)
|
||||||
|
return data
|
||||||
|
|
||||||
|
async def upsert_connection(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
owner_user_id: str,
|
||||||
|
provider: str,
|
||||||
|
external_account_id: str | None = None,
|
||||||
|
external_account_name: str | None = None,
|
||||||
|
workspace_id: str | None = None,
|
||||||
|
workspace_name: str | None = None,
|
||||||
|
bot_user_id: str | None = None,
|
||||||
|
scopes: list[str] | None = None,
|
||||||
|
capabilities: dict[str, Any] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
status: str = "connected",
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
external_account_id_value = self._normalize_optional_identity(external_account_id)
|
||||||
|
workspace_id_value = self._normalize_optional_identity(workspace_id)
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
stmt = select(ChannelConnectionRow).where(
|
||||||
|
ChannelConnectionRow.owner_user_id == owner_user_id,
|
||||||
|
ChannelConnectionRow.provider == provider,
|
||||||
|
ChannelConnectionRow.external_account_id == external_account_id_value,
|
||||||
|
ChannelConnectionRow.workspace_id == workspace_id_value,
|
||||||
|
)
|
||||||
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
row = ChannelConnectionRow(
|
||||||
|
id=self._new_id(),
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider=provider,
|
||||||
|
external_account_id=external_account_id_value,
|
||||||
|
workspace_id=workspace_id_value,
|
||||||
|
)
|
||||||
|
session.add(row)
|
||||||
|
|
||||||
|
row.status = status
|
||||||
|
row.external_account_name = external_account_name
|
||||||
|
row.workspace_name = workspace_name
|
||||||
|
row.bot_user_id = bot_user_id
|
||||||
|
row.scopes_json = list(scopes or [])
|
||||||
|
row.capabilities_json = dict(capabilities or {})
|
||||||
|
row.metadata_json = dict(metadata or {})
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(row)
|
||||||
|
return self._connection_to_dict(row)
|
||||||
|
|
||||||
|
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
|
||||||
|
return [self._connection_to_dict(row) for row in result.scalars()]
|
||||||
|
|
||||||
|
async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
row = await session.get(ChannelConnectionRow, connection_id)
|
||||||
|
if row is None or row.owner_user_id != owner_user_id:
|
||||||
|
return False
|
||||||
|
|
||||||
|
row.status = "revoked"
|
||||||
|
credential = await session.get(ChannelCredentialRow, connection_id)
|
||||||
|
if credential is not None:
|
||||||
|
await session.delete(credential)
|
||||||
|
await session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def store_credentials(
|
||||||
|
self,
|
||||||
|
connection_id: str,
|
||||||
|
*,
|
||||||
|
access_token: str | None,
|
||||||
|
refresh_token: str | None = None,
|
||||||
|
token_type: str | None = None,
|
||||||
|
expires_at: datetime | None = None,
|
||||||
|
refresh_expires_at: datetime | None = None,
|
||||||
|
extra: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
if self._cipher is None:
|
||||||
|
raise RuntimeError("channel connection encryption key is required")
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
row = await session.get(ChannelCredentialRow, connection_id)
|
||||||
|
if row is None:
|
||||||
|
row = ChannelCredentialRow(connection_id=connection_id)
|
||||||
|
session.add(row)
|
||||||
|
row.encrypted_access_token = self._cipher.encrypt_text(access_token)
|
||||||
|
row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
|
||||||
|
row.token_type = token_type
|
||||||
|
row.expires_at = expires_at
|
||||||
|
row.refresh_expires_at = refresh_expires_at
|
||||||
|
row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
|
||||||
|
row.version = (row.version or 0) + 1
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
|
||||||
|
if self._cipher is None:
|
||||||
|
return None
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
row = await session.get(ChannelCredentialRow, connection_id)
|
||||||
|
if row is None:
|
||||||
|
return None
|
||||||
|
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
|
||||||
|
return {
|
||||||
|
"connection_id": row.connection_id,
|
||||||
|
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
|
||||||
|
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
|
||||||
|
"token_type": row.token_type,
|
||||||
|
"expires_at": self._coerce_datetime(row.expires_at),
|
||||||
|
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
|
||||||
|
"extra": json.loads(extra_raw) if extra_raw else {},
|
||||||
|
}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def hash_state(state: str) -> str:
|
||||||
|
return hashlib.sha256(state.encode("utf-8")).hexdigest()
|
||||||
|
|
||||||
|
async def create_oauth_state(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
owner_user_id: str,
|
||||||
|
provider: str,
|
||||||
|
state: str,
|
||||||
|
expires_at: datetime,
|
||||||
|
code_verifier: str | None = None,
|
||||||
|
nonce_hash: str | None = None,
|
||||||
|
redirect_after: str | None = None,
|
||||||
|
requested_scopes: list[str] | None = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> None:
|
||||||
|
row = ChannelOAuthStateRow(
|
||||||
|
state_hash=self.hash_state(state),
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider=provider,
|
||||||
|
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
|
||||||
|
nonce_hash=nonce_hash,
|
||||||
|
redirect_after=redirect_after,
|
||||||
|
requested_scopes_json=list(requested_scopes or []),
|
||||||
|
metadata_json=dict(metadata or {}),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
session.add(row)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(ChannelOAuthStateRow).where(
|
||||||
|
ChannelOAuthStateRow.owner_user_id == owner_user_id,
|
||||||
|
ChannelOAuthStateRow.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return len(list(result.scalars()))
|
||||||
|
|
||||||
|
async def consume_oauth_state(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
provider: str,
|
||||||
|
state: str,
|
||||||
|
now: datetime | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
current_time = now or datetime.now(UTC)
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
|
||||||
|
row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
|
||||||
|
if row is None or row.provider != provider or row.consumed_at is not None:
|
||||||
|
await session.commit()
|
||||||
|
return None
|
||||||
|
expires_at = self._coerce_datetime(row.expires_at)
|
||||||
|
if expires_at is not None and expires_at < current_time:
|
||||||
|
await session.commit()
|
||||||
|
return None
|
||||||
|
|
||||||
|
row.consumed_at = current_time
|
||||||
|
await session.commit()
|
||||||
|
return {
|
||||||
|
"owner_user_id": row.owner_user_id,
|
||||||
|
"provider": row.provider,
|
||||||
|
"requested_scopes": row.requested_scopes_json or [],
|
||||||
|
"metadata": row.metadata_json or {},
|
||||||
|
"redirect_after": row.redirect_after,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def find_connection_by_external_identity(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
provider: str,
|
||||||
|
external_account_id: str,
|
||||||
|
workspace_id: str | None = None,
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(ChannelConnectionRow)
|
||||||
|
.where(
|
||||||
|
ChannelConnectionRow.provider == provider,
|
||||||
|
ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
|
||||||
|
ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
|
||||||
|
ChannelConnectionRow.status == "connected",
|
||||||
|
)
|
||||||
|
.order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
return self._connection_to_dict(row) if row is not None else None
|
||||||
|
|
||||||
|
async def set_thread_id(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
connection_id: str,
|
||||||
|
owner_user_id: str,
|
||||||
|
provider: str,
|
||||||
|
external_conversation_id: str,
|
||||||
|
thread_id: str,
|
||||||
|
external_topic_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
topic_id = external_topic_id or ""
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
stmt = select(ChannelConversationRow).where(
|
||||||
|
ChannelConversationRow.connection_id == connection_id,
|
||||||
|
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
||||||
|
ChannelConversationRow.external_topic_id == topic_id,
|
||||||
|
)
|
||||||
|
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
row = ChannelConversationRow(
|
||||||
|
id=self._new_id(),
|
||||||
|
connection_id=connection_id,
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider=provider,
|
||||||
|
external_conversation_id=external_conversation_id,
|
||||||
|
external_topic_id=topic_id,
|
||||||
|
thread_id=thread_id,
|
||||||
|
)
|
||||||
|
session.add(row)
|
||||||
|
else:
|
||||||
|
row.thread_id = thread_id
|
||||||
|
row.owner_user_id = owner_user_id
|
||||||
|
row.provider = provider
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async def get_thread_id(
|
||||||
|
self,
|
||||||
|
connection_id: str,
|
||||||
|
external_conversation_id: str,
|
||||||
|
external_topic_id: str | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
async with self.session_factory() as session:
|
||||||
|
stmt = select(ChannelConversationRow.thread_id).where(
|
||||||
|
ChannelConversationRow.connection_id == connection_id,
|
||||||
|
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
||||||
|
ChannelConversationRow.external_topic_id == (external_topic_id or ""),
|
||||||
|
)
|
||||||
|
return (await session.execute(stmt)).scalar_one_or_none()
|
||||||
@@ -14,10 +14,26 @@ its storage implementation lives in ``deerflow.runtime.events.store.db`` and
|
|||||||
there is no matching entity directory.
|
there is no matching entity directory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from deerflow.persistence.channel_connections.model import (
|
||||||
|
ChannelConnectionRow,
|
||||||
|
ChannelConversationRow,
|
||||||
|
ChannelCredentialRow,
|
||||||
|
ChannelOAuthStateRow,
|
||||||
|
)
|
||||||
from deerflow.persistence.feedback.model import FeedbackRow
|
from deerflow.persistence.feedback.model import FeedbackRow
|
||||||
from deerflow.persistence.models.run_event import RunEventRow
|
from deerflow.persistence.models.run_event import RunEventRow
|
||||||
from deerflow.persistence.run.model import RunRow
|
from deerflow.persistence.run.model import RunRow
|
||||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||||
from deerflow.persistence.user.model import UserRow
|
from deerflow.persistence.user.model import UserRow
|
||||||
|
|
||||||
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
|
__all__ = [
|
||||||
|
"ChannelConnectionRow",
|
||||||
|
"ChannelConversationRow",
|
||||||
|
"ChannelCredentialRow",
|
||||||
|
"ChannelOAuthStateRow",
|
||||||
|
"FeedbackRow",
|
||||||
|
"RunEventRow",
|
||||||
|
"RunRow",
|
||||||
|
"ThreadMetaRow",
|
||||||
|
"UserRow",
|
||||||
|
]
|
||||||
|
|||||||
@@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
|
"""Move a thread metadata row to a new owner.
|
||||||
|
|
||||||
|
Intended for trusted internal repair/migration paths. No-op if the
|
||||||
|
row does not exist or the caller fails the owner check.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||||
|
|||||||
@@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
record["updated_at"] = now_iso()
|
record["updated_at"] = now_iso()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
|
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
|
||||||
|
if record is None:
|
||||||
|
return
|
||||||
|
record["user_id"] = owner_user_id
|
||||||
|
record["updated_at"] = now_iso()
|
||||||
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||||
if record is None:
|
if record is None:
|
||||||
|
|||||||
@@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
row.updated_at = datetime.now(UTC)
|
row.updated_at = datetime.now(UTC)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
async def update_owner(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
owner_user_id: str,
|
||||||
|
*,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
|
) -> None:
|
||||||
|
"""Move a thread metadata row to ``owner_user_id``."""
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
|
||||||
|
async with self._sf() as session:
|
||||||
|
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||||
|
return
|
||||||
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def delete(
|
async def delete(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
|
|||||||
@@ -164,7 +164,18 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
metadata={"caller": caller, **(metadata or {})},
|
metadata={"caller": caller, **(metadata or {})},
|
||||||
)
|
)
|
||||||
|
|
||||||
def on_chain_end(self, outputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_chain_end(
|
||||||
|
self,
|
||||||
|
outputs: Any,
|
||||||
|
*,
|
||||||
|
run_id: UUID,
|
||||||
|
parent_run_id: UUID | None = None,
|
||||||
|
**kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
# Nested chain ends fire for internal graph nodes; only the root chain
|
||||||
|
# represents the user-visible run lifecycle.
|
||||||
|
if parent_run_id is not None:
|
||||||
|
return
|
||||||
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
self._put(event_type="run.end", category="outputs", content=outputs, metadata={"status": "success"})
|
||||||
self._flush_sync()
|
self._flush_sync()
|
||||||
|
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ class RunRecord:
|
|||||||
multitask_strategy: str = "reject"
|
multitask_strategy: str = "reject"
|
||||||
metadata: dict = field(default_factory=dict)
|
metadata: dict = field(default_factory=dict)
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
user_id: str | None = None
|
||||||
created_at: str = ""
|
created_at: str = ""
|
||||||
updated_at: str = ""
|
updated_at: str = ""
|
||||||
task: asyncio.Task | None = field(default=None, repr=False)
|
task: asyncio.Task | None = field(default=None, repr=False)
|
||||||
@@ -124,7 +125,7 @@ class RunManager:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
||||||
return {
|
payload = {
|
||||||
"thread_id": record.thread_id,
|
"thread_id": record.thread_id,
|
||||||
"assistant_id": record.assistant_id,
|
"assistant_id": record.assistant_id,
|
||||||
"status": record.status.value,
|
"status": record.status.value,
|
||||||
@@ -135,6 +136,9 @@ class RunManager:
|
|||||||
"created_at": record.created_at,
|
"created_at": record.created_at,
|
||||||
"model_name": record.model_name,
|
"model_name": record.model_name,
|
||||||
}
|
}
|
||||||
|
if record.user_id is not None:
|
||||||
|
payload["user_id"] = record.user_id
|
||||||
|
return payload
|
||||||
|
|
||||||
async def _call_store_with_retry(
|
async def _call_store_with_retry(
|
||||||
self,
|
self,
|
||||||
@@ -241,6 +245,7 @@ class RunManager:
|
|||||||
kwargs=row.get("kwargs") or {},
|
kwargs=row.get("kwargs") or {},
|
||||||
created_at=row.get("created_at") or "",
|
created_at=row.get("created_at") or "",
|
||||||
updated_at=row.get("updated_at") or "",
|
updated_at=row.get("updated_at") or "",
|
||||||
|
user_id=row.get("user_id"),
|
||||||
error=row.get("error"),
|
error=row.get("error"),
|
||||||
model_name=row.get("model_name"),
|
model_name=row.get("model_name"),
|
||||||
store_only=True,
|
store_only=True,
|
||||||
@@ -320,6 +325,7 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
|
user_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Create a new pending run and register it."""
|
"""Create a new pending run and register it."""
|
||||||
run_id = str(uuid.uuid4())
|
run_id = str(uuid.uuid4())
|
||||||
@@ -333,6 +339,7 @@ class RunManager:
|
|||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
kwargs=kwargs or {},
|
kwargs=kwargs or {},
|
||||||
|
user_id=user_id,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
@@ -504,6 +511,7 @@ class RunManager:
|
|||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Atomically check for inflight runs and create a new one.
|
"""Atomically check for inflight runs and create a new one.
|
||||||
|
|
||||||
@@ -546,6 +554,7 @@ class RunManager:
|
|||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
kwargs=kwargs or {},
|
kwargs=kwargs or {},
|
||||||
|
user_id=user_id,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
|||||||
@@ -147,7 +147,17 @@ class LocalSandboxProvider(SandboxProvider):
|
|||||||
mount.container_path,
|
mount.container_path,
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
# Ensure the host path exists before adding mapping
|
# Ensure the host path exists before adding mapping.
|
||||||
|
#
|
||||||
|
# ``host_path`` is resolved against the filesystem of the
|
||||||
|
# process running this provider — for ``make dev`` that is
|
||||||
|
# the host machine, but for ``make up`` it is the
|
||||||
|
# ``deer-flow-gateway`` container, so any host path that
|
||||||
|
# isn't bind-mounted into the gateway image will be missing
|
||||||
|
# here. Skipping silently makes this a high-cost-to-debug
|
||||||
|
# silent failure (sandbox skill / tool reads an empty dir
|
||||||
|
# instead of the configured mount), so escalate to ERROR
|
||||||
|
# and include actionable guidance. See #3244.
|
||||||
if host_path.exists():
|
if host_path.exists():
|
||||||
mappings.append(
|
mappings.append(
|
||||||
PathMapping(
|
PathMapping(
|
||||||
@@ -157,10 +167,16 @@ class LocalSandboxProvider(SandboxProvider):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning(
|
logger.error(
|
||||||
"Mount host_path does not exist, skipping: %s -> %s",
|
"sandbox.mounts entry %s -> %s ignored: host_path %s does not exist from the "
|
||||||
|
"perspective of the gateway process. In Docker deployments (make up / docker-compose), "
|
||||||
|
"this path must also be bind-mounted into the gateway container — add a matching "
|
||||||
|
"volume entry under services.gateway.volumes in docker/docker-compose.yaml (and use "
|
||||||
|
"the in-container path here), or run in local mode (make dev) where the gateway sees "
|
||||||
|
"the host filesystem directly.",
|
||||||
mount.host_path,
|
mount.host_path,
|
||||||
mount.container_path,
|
mount.container_path,
|
||||||
|
mount.host_path,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# Log but don't fail if config loading fails
|
# Log but don't fail if config loading fails
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ dependencies = [
|
|||||||
"sqlalchemy[asyncio]>=2.0,<3.0",
|
"sqlalchemy[asyncio]>=2.0,<3.0",
|
||||||
"aiosqlite>=0.19",
|
"aiosqlite>=0.19",
|
||||||
"alembic>=1.13",
|
"alembic>=1.13",
|
||||||
|
"cryptography>=43.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
|||||||
@@ -0,0 +1,251 @@
|
|||||||
|
"""Connection binding tests for browser-connectable IM channels beyond Telegram/Slack/Discord."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from app.channels.message_bus import InboundMessage, MessageBus
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_repo(tmp_path, name: str):
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||||
|
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||||
|
|
||||||
|
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / f'{name}.db'}", sqlite_dir=str(tmp_path))
|
||||||
|
return ChannelConnectionRepository(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
|
async def _seed_state(repo, provider: str, state: str, owner_user_id: str = "deerflow-user-1") -> None:
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id=owner_user_id,
|
||||||
|
provider=provider,
|
||||||
|
state=state,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_feishu_connect_command_binds_identity(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.feishu import FeishuChannel
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = await _make_repo(tmp_path, "feishu")
|
||||||
|
state = "feishu-bind-code"
|
||||||
|
await _seed_state(repo, "feishu", state)
|
||||||
|
channel = FeishuChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"app_id": "app", "app_secret": "secret", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
channel._reply_card = AsyncMock()
|
||||||
|
|
||||||
|
handled = await channel._bind_connection_from_connect_code(
|
||||||
|
message_id="om-message-1",
|
||||||
|
chat_id="oc-chat-1",
|
||||||
|
user_id="ou-user-1",
|
||||||
|
code=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
connections = await repo.list_connections("deerflow-user-1")
|
||||||
|
assert handled is True
|
||||||
|
assert len(connections) == 1
|
||||||
|
assert connections[0]["provider"] == "feishu"
|
||||||
|
assert connections[0]["external_account_id"] == "ou-user-1"
|
||||||
|
assert connections[0]["workspace_id"] == "oc-chat-1"
|
||||||
|
channel._reply_card.assert_awaited_once_with("om-message-1", "Feishu connected to DeerFlow.")
|
||||||
|
await repo.close()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
|
|
||||||
|
|
||||||
|
def test_dingtalk_connect_command_binds_identity(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = await _make_repo(tmp_path, "dingtalk")
|
||||||
|
state = "dingtalk-bind-code"
|
||||||
|
await _seed_state(repo, "dingtalk", state)
|
||||||
|
channel = DingTalkChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"client_id": "client", "client_secret": "secret", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
channel._send_connection_reply = AsyncMock()
|
||||||
|
|
||||||
|
handled = await channel._bind_connection_from_connect_code(
|
||||||
|
conversation_type=_CONVERSATION_TYPE_GROUP,
|
||||||
|
sender_staff_id="staff-user-1",
|
||||||
|
sender_nick="Alice",
|
||||||
|
conversation_id="cid-group-1",
|
||||||
|
code=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
connections = await repo.list_connections("deerflow-user-1")
|
||||||
|
assert handled is True
|
||||||
|
assert len(connections) == 1
|
||||||
|
assert connections[0]["provider"] == "dingtalk"
|
||||||
|
assert connections[0]["external_account_id"] == "staff-user-1"
|
||||||
|
assert connections[0]["external_account_name"] == "Alice"
|
||||||
|
assert connections[0]["workspace_id"] == "cid-group-1"
|
||||||
|
channel._send_connection_reply.assert_awaited_once()
|
||||||
|
await repo.close()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wechat_connect_command_binds_identity(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.wechat import WechatChannel
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = await _make_repo(tmp_path, "wechat")
|
||||||
|
state = "wechat-bind-code"
|
||||||
|
await _seed_state(repo, "wechat", state)
|
||||||
|
channel = WechatChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"bot_token": "token", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
channel._send_connection_reply = AsyncMock()
|
||||||
|
|
||||||
|
handled = await channel._bind_connection_from_connect_code(
|
||||||
|
chat_id="wx-user-1",
|
||||||
|
context_token="ctx-1",
|
||||||
|
code=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
connections = await repo.list_connections("deerflow-user-1")
|
||||||
|
assert handled is True
|
||||||
|
assert len(connections) == 1
|
||||||
|
assert connections[0]["provider"] == "wechat"
|
||||||
|
assert connections[0]["external_account_id"] == "wx-user-1"
|
||||||
|
assert connections[0]["workspace_id"] == "wx-user-1"
|
||||||
|
channel._send_connection_reply.assert_awaited_once_with("wx-user-1", "ctx-1", "WeChat connected to DeerFlow.")
|
||||||
|
await repo.close()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
|
|
||||||
|
|
||||||
|
def test_wecom_connect_command_binds_identity(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.wecom import WeComChannel
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = await _make_repo(tmp_path, "wecom")
|
||||||
|
state = "wecom-bind-code"
|
||||||
|
await _seed_state(repo, "wecom", state)
|
||||||
|
channel = WeComChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"bot_id": "bot", "bot_secret": "secret", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
channel._ws_client = MagicMock()
|
||||||
|
channel._ws_client.reply = AsyncMock()
|
||||||
|
frame = {"body": {"aibotid": "bot-1", "chattype": "single"}}
|
||||||
|
|
||||||
|
handled = await channel._bind_connection_from_connect_code(
|
||||||
|
frame=frame,
|
||||||
|
user_id="wecom-user-1",
|
||||||
|
code=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
connections = await repo.list_connections("deerflow-user-1")
|
||||||
|
assert handled is True
|
||||||
|
assert len(connections) == 1
|
||||||
|
assert connections[0]["provider"] == "wecom"
|
||||||
|
assert connections[0]["external_account_id"] == "wecom-user-1"
|
||||||
|
assert connections[0]["workspace_id"] == "bot-1"
|
||||||
|
channel._ws_client.reply.assert_awaited_once_with(frame, {"msgtype": "text", "text": {"content": "WeCom connected to DeerFlow."}})
|
||||||
|
await repo.close()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
|
|
||||||
|
|
||||||
|
def test_additional_channels_attach_owner_identity(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
|
||||||
|
from app.channels.feishu import FeishuChannel
|
||||||
|
from app.channels.wechat import WechatChannel
|
||||||
|
from app.channels.wecom import WeComChannel
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = await _make_repo(tmp_path, "additional-identity")
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="feishu",
|
||||||
|
external_account_id="ou-user-1",
|
||||||
|
workspace_id="oc-chat-1",
|
||||||
|
)
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="dingtalk",
|
||||||
|
external_account_id="staff-user-1",
|
||||||
|
workspace_id="cid-group-1",
|
||||||
|
)
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="wechat",
|
||||||
|
external_account_id="wx-user-1",
|
||||||
|
workspace_id="wx-user-1",
|
||||||
|
)
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="wecom",
|
||||||
|
external_account_id="wecom-user-1",
|
||||||
|
workspace_id="bot-1",
|
||||||
|
)
|
||||||
|
|
||||||
|
cases = [
|
||||||
|
(
|
||||||
|
FeishuChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||||
|
InboundMessage(channel_name="feishu", chat_id="oc-chat-1", user_id="ou-user-1", text="hello"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
DingTalkChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||||
|
InboundMessage(
|
||||||
|
channel_name="dingtalk",
|
||||||
|
chat_id="cid-group-1",
|
||||||
|
user_id="staff-user-1",
|
||||||
|
text="hello",
|
||||||
|
metadata={
|
||||||
|
"conversation_type": _CONVERSATION_TYPE_GROUP,
|
||||||
|
"conversation_id": "cid-group-1",
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
WechatChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||||
|
InboundMessage(channel_name="wechat", chat_id="wx-user-1", user_id="wx-user-1", text="hello"),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
WeComChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||||
|
InboundMessage(
|
||||||
|
channel_name="wecom",
|
||||||
|
chat_id="wecom-user-1",
|
||||||
|
user_id="wecom-user-1",
|
||||||
|
text="hello",
|
||||||
|
metadata={"aibotid": "bot-1"},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
for channel, inbound in cases:
|
||||||
|
attached = await channel._attach_connection_identity(inbound)
|
||||||
|
assert attached.owner_user_id == "deerflow-user-1"
|
||||||
|
assert attached.connection_id
|
||||||
|
assert (
|
||||||
|
attached.workspace_id
|
||||||
|
== {
|
||||||
|
"feishu": "oc-chat-1",
|
||||||
|
"dingtalk": "cid-group-1",
|
||||||
|
"wechat": "wx-user-1",
|
||||||
|
"wecom": "bot-1",
|
||||||
|
}[channel.name]
|
||||||
|
)
|
||||||
|
|
||||||
|
await repo.close()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
@@ -4,6 +4,7 @@ import pytest
|
|||||||
from starlette.testclient import TestClient
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||||
|
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||||
|
|
||||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -38,6 +39,8 @@ def test_public_paths(path: str):
|
|||||||
"/api/threads/123/uploads",
|
"/api/threads/123/uploads",
|
||||||
"/api/agents",
|
"/api/agents",
|
||||||
"/api/channels",
|
"/api/channels",
|
||||||
|
"/api/channels/providers",
|
||||||
|
"/api/channels/slack/connect",
|
||||||
"/api/runs/stream",
|
"/api/runs/stream",
|
||||||
"/api/threads/123/runs",
|
"/api/threads/123/runs",
|
||||||
"/api/v1/auth/me",
|
"/api/v1/auth/me",
|
||||||
@@ -88,7 +91,9 @@ def test_unknown_api_path_is_protected():
|
|||||||
|
|
||||||
def _make_app():
|
def _make_app():
|
||||||
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI, Request
|
||||||
|
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.add_middleware(AuthMiddleware)
|
app.add_middleware(AuthMiddleware)
|
||||||
@@ -98,8 +103,16 @@ def _make_app():
|
|||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
@app.get("/api/v1/auth/me")
|
@app.get("/api/v1/auth/me")
|
||||||
async def auth_me():
|
async def auth_me(request: Request):
|
||||||
return {"id": "1", "email": "test@test.com"}
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"email": user.email,
|
||||||
|
"system_role": user.system_role,
|
||||||
|
"needs_setup": user.needs_setup,
|
||||||
|
}
|
||||||
|
|
||||||
@app.get("/api/v1/auth/setup-status")
|
@app.get("/api/v1/auth/setup-status")
|
||||||
async def setup_status():
|
async def setup_status():
|
||||||
@@ -109,6 +122,29 @@ def _make_app():
|
|||||||
async def models_get():
|
async def models_get():
|
||||||
return {"models": []}
|
return {"models": []}
|
||||||
|
|
||||||
|
@app.get("/api/whoami")
|
||||||
|
async def whoami(request: Request):
|
||||||
|
user = request.state.user
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"email": getattr(user, "email", None),
|
||||||
|
"system_role": getattr(user, "system_role", None),
|
||||||
|
"context_user_id": get_effective_user_id(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@app.get("/api/current-user-from-dep")
|
||||||
|
async def current_user_from_dep(request: Request):
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
state_user = request.state.user
|
||||||
|
return {
|
||||||
|
"id": str(user.id),
|
||||||
|
"state_id": str(state_user.id),
|
||||||
|
"auth_source": request.state.auth_source,
|
||||||
|
"context_user_id": get_effective_user_id(),
|
||||||
|
}
|
||||||
|
|
||||||
@app.put("/api/mcp/config")
|
@app.put("/api/mcp/config")
|
||||||
async def mcp_put():
|
async def mcp_put():
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
@@ -132,8 +168,24 @@ def _make_app():
|
|||||||
return app
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def _make_auth_csrf_app():
|
||||||
|
"""Create a minimal app with production middleware ordering."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
app.add_middleware(CSRFMiddleware)
|
||||||
|
|
||||||
|
@app.post("/api/threads/abc/runs/stream")
|
||||||
|
async def protected_mutation():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def client():
|
def client(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "")
|
||||||
return TestClient(_make_app())
|
return TestClient(_make_app())
|
||||||
|
|
||||||
|
|
||||||
@@ -161,11 +213,145 @@ def test_protected_path_no_cookie_returns_401(client):
|
|||||||
assert body["detail"]["code"] == "not_authenticated"
|
assert body["detail"]["code"] == "not_authenticated"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/models")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {"models": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_stamps_default_admin_user_without_cookie(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/whoami")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": "default",
|
||||||
|
"email": "default@test.local",
|
||||||
|
"system_role": "admin",
|
||||||
|
"context_user_id": "default",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/v1/auth/me")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": "default",
|
||||||
|
"email": "default@test.local",
|
||||||
|
"system_role": "admin",
|
||||||
|
"needs_setup": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_does_not_clobber_valid_session_cookie(monkeypatch):
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
async def fake_current_user(request):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id="session-user",
|
||||||
|
email="session@test.local",
|
||||||
|
system_role="user",
|
||||||
|
needs_setup=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.setattr("app.gateway.deps.get_current_user_from_request", fake_current_user)
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/whoami", cookies={"access_token": "valid-session"})
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": "session-user",
|
||||||
|
"email": "session@test.local",
|
||||||
|
"system_role": "user",
|
||||||
|
"context_user_id": "session-user",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_does_not_clobber_internal_auth_identity(monkeypatch):
|
||||||
|
from app.gateway.internal_auth import create_internal_auth_headers
|
||||||
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get(
|
||||||
|
"/api/current-user-from-dep",
|
||||||
|
headers=create_internal_auth_headers(),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {
|
||||||
|
"id": DEFAULT_USER_ID,
|
||||||
|
"state_id": DEFAULT_USER_ID,
|
||||||
|
"auth_source": "internal",
|
||||||
|
"context_user_id": DEFAULT_USER_ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_skips_csrf_for_state_changing_requests(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
client = TestClient(_make_auth_csrf_app())
|
||||||
|
|
||||||
|
res = client.post("/api/threads/abc/runs/stream")
|
||||||
|
|
||||||
|
assert res.status_code == 200
|
||||||
|
assert res.json() == {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_is_ignored_in_explicit_production_env(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.setenv("DEER_FLOW_ENV", "production")
|
||||||
|
client = TestClient(_make_app())
|
||||||
|
|
||||||
|
res = client.get("/api/models")
|
||||||
|
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog):
|
||||||
|
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.delenv("DEER_FLOW_ENV", raising=False)
|
||||||
|
monkeypatch.delenv("ENVIRONMENT", raising=False)
|
||||||
|
|
||||||
|
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
|
||||||
|
warn_if_auth_disabled_enabled()
|
||||||
|
|
||||||
|
assert "authentication is bypassed" in caplog.text
|
||||||
|
assert "default" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog):
|
||||||
|
from app.gateway.auth_disabled import warn_if_auth_disabled_enabled
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
monkeypatch.setenv("ENVIRONMENT", "production")
|
||||||
|
|
||||||
|
with caplog.at_level("WARNING", logger="app.gateway.auth_disabled"):
|
||||||
|
warn_if_auth_disabled_enabled()
|
||||||
|
|
||||||
|
assert "authentication is bypassed" not in caplog.text
|
||||||
|
|
||||||
|
|
||||||
def test_protected_path_with_junk_cookie_rejected(client):
|
def test_protected_path_with_junk_cookie_rejected(client):
|
||||||
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
||||||
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
||||||
tokens through to the route handler."""
|
tokens through to the route handler."""
|
||||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
client.cookies.set("access_token", "some-token")
|
||||||
|
res = client.get("/api/models")
|
||||||
assert res.status_code == 401
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"""Tests for user-facing IM channel connection configuration."""
|
||||||
|
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_connections_disabled_by_default():
|
||||||
|
config = ChannelConnectionsConfig()
|
||||||
|
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.slack.enabled is False
|
||||||
|
assert config.telegram.enabled is False
|
||||||
|
assert config.discord.enabled is False
|
||||||
|
assert config.feishu.enabled is False
|
||||||
|
assert config.dingtalk.enabled is False
|
||||||
|
assert config.wechat.enabled is False
|
||||||
|
assert config.wecom.enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_enabled_channel_connections_do_not_require_public_url_or_encryption_key():
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"telegram": {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_username": "deerflow_bot",
|
||||||
|
},
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
"discord": {"enabled": True},
|
||||||
|
"feishu": {"enabled": True},
|
||||||
|
"dingtalk": {"enabled": True},
|
||||||
|
"wechat": {"enabled": True},
|
||||||
|
"wecom": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.enabled is True
|
||||||
|
assert config.provider_status("telegram") == {"enabled": True, "configured": True}
|
||||||
|
assert config.provider_status("slack") == {"enabled": True, "configured": True}
|
||||||
|
assert config.provider_status("discord") == {"enabled": True, "configured": True}
|
||||||
|
assert config.provider_status("feishu") == {"enabled": True, "configured": True}
|
||||||
|
assert config.provider_status("dingtalk") == {"enabled": True, "configured": True}
|
||||||
|
assert config.provider_status("wechat") == {"enabled": True, "configured": True}
|
||||||
|
assert config.provider_status("wecom") == {"enabled": True, "configured": True}
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_status_reports_disabled_and_unknown_providers():
|
||||||
|
config = ChannelConnectionsConfig.model_validate({"enabled": True})
|
||||||
|
|
||||||
|
assert config.provider_status("slack") == {"enabled": False, "configured": False}
|
||||||
|
assert config.provider_status("telegram") == {"enabled": False, "configured": False}
|
||||||
|
assert config.provider_status("discord") == {"enabled": False, "configured": False}
|
||||||
|
assert config.provider_status("feishu") == {"enabled": False, "configured": False}
|
||||||
|
assert config.provider_status("dingtalk") == {"enabled": False, "configured": False}
|
||||||
|
assert config.provider_status("wechat") == {"enabled": False, "configured": False}
|
||||||
|
assert config.provider_status("wecom") == {"enabled": False, "configured": False}
|
||||||
|
assert config.provider_status("unknown") == {"enabled": False, "configured": False}
|
||||||
@@ -0,0 +1,226 @@
|
|||||||
|
"""Tests for per-user IM channel connection persistence."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from deerflow.persistence.channel_connections import (
|
||||||
|
ChannelConnectionRepository,
|
||||||
|
ChannelConnectionRow,
|
||||||
|
ChannelCredentialCipher,
|
||||||
|
ChannelCredentialRow,
|
||||||
|
ChannelOAuthStateRow,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def repo(tmp_path):
|
||||||
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
|
|
||||||
|
url = f"sqlite+aiosqlite:///{tmp_path / 'channels.db'}"
|
||||||
|
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||||
|
try:
|
||||||
|
yield ChannelConnectionRepository(
|
||||||
|
get_session_factory(),
|
||||||
|
cipher=ChannelCredentialCipher.from_key("test-encryption-key"),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await close_engine()
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannelConnectionRepository:
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_connections_are_listed_per_owner(self, repo):
|
||||||
|
alice = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-alice",
|
||||||
|
external_account_name="Alice",
|
||||||
|
workspace_id="T1",
|
||||||
|
workspace_name="Team One",
|
||||||
|
scopes=["chat:write"],
|
||||||
|
)
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id="bob",
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-bob",
|
||||||
|
external_account_name="Bob",
|
||||||
|
workspace_id="T1",
|
||||||
|
workspace_name="Team One",
|
||||||
|
scopes=["chat:write"],
|
||||||
|
)
|
||||||
|
|
||||||
|
results = await repo.list_connections("alice")
|
||||||
|
|
||||||
|
assert [item["id"] for item in results] == [alice["id"]]
|
||||||
|
assert results[0]["owner_user_id"] == "alice"
|
||||||
|
assert results[0]["provider"] == "slack"
|
||||||
|
assert results[0]["scopes"] == ["chat:write"]
|
||||||
|
assert "encrypted_access_token" not in results[0]
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_upsert_connection_updates_existing_provider_identity(self, repo):
|
||||||
|
first = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
external_account_name="Alice",
|
||||||
|
workspace_id=None,
|
||||||
|
workspace_name=None,
|
||||||
|
status="pending",
|
||||||
|
)
|
||||||
|
second = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
external_account_name="Alice Telegram",
|
||||||
|
workspace_id=None,
|
||||||
|
workspace_name=None,
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert second["id"] == first["id"]
|
||||||
|
assert second["status"] == "connected"
|
||||||
|
assert second["external_account_name"] == "Alice Telegram"
|
||||||
|
assert len(await repo.list_connections("alice")) == 1
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
|
||||||
|
connection = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-alice",
|
||||||
|
workspace_id="T1",
|
||||||
|
)
|
||||||
|
expires_at = datetime.now(UTC) + timedelta(hours=1)
|
||||||
|
|
||||||
|
await repo.store_credentials(
|
||||||
|
connection["id"],
|
||||||
|
access_token="xoxb-secret-access-token",
|
||||||
|
refresh_token="secret-refresh-token",
|
||||||
|
token_type="Bearer",
|
||||||
|
expires_at=expires_at,
|
||||||
|
extra={"bot_user_id": "B123"},
|
||||||
|
)
|
||||||
|
|
||||||
|
async with repo.session_factory() as session:
|
||||||
|
row = (await session.execute(select(ChannelCredentialRow))).scalar_one()
|
||||||
|
assert row.encrypted_access_token is not None
|
||||||
|
assert "xoxb-secret-access-token" not in row.encrypted_access_token
|
||||||
|
assert "secret-refresh-token" not in (row.encrypted_refresh_token or "")
|
||||||
|
assert "B123" not in (row.encrypted_extra_json or "")
|
||||||
|
|
||||||
|
credentials = await repo.get_credentials(connection["id"])
|
||||||
|
|
||||||
|
assert credentials is not None
|
||||||
|
assert credentials["access_token"] == "xoxb-secret-access-token"
|
||||||
|
assert credentials["refresh_token"] == "secret-refresh-token"
|
||||||
|
assert credentials["token_type"] == "Bearer"
|
||||||
|
assert credentials["expires_at"] == expires_at
|
||||||
|
assert credentials["extra"] == {"bot_user_id": "B123"}
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_conversations_are_scoped_by_connection(self, repo):
|
||||||
|
alice = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-alice",
|
||||||
|
workspace_id="T1",
|
||||||
|
)
|
||||||
|
bob = await repo.upsert_connection(
|
||||||
|
owner_user_id="bob",
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-bob",
|
||||||
|
workspace_id="T1",
|
||||||
|
)
|
||||||
|
|
||||||
|
await repo.set_thread_id(
|
||||||
|
connection_id=alice["id"],
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
external_conversation_id="C-shared",
|
||||||
|
external_topic_id="1710000000.000100",
|
||||||
|
thread_id="thread-alice",
|
||||||
|
)
|
||||||
|
await repo.set_thread_id(
|
||||||
|
connection_id=bob["id"],
|
||||||
|
owner_user_id="bob",
|
||||||
|
provider="slack",
|
||||||
|
external_conversation_id="C-shared",
|
||||||
|
external_topic_id="1710000000.000100",
|
||||||
|
thread_id="thread-bob",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
|
||||||
|
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo):
|
||||||
|
connection = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
)
|
||||||
|
await repo.store_credentials(connection["id"], access_token="secret-token")
|
||||||
|
|
||||||
|
disconnected = await repo.disconnect_connection(
|
||||||
|
connection_id=connection["id"],
|
||||||
|
owner_user_id="alice",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert disconnected is True
|
||||||
|
async with repo.session_factory() as session:
|
||||||
|
connection_row = await session.get(ChannelConnectionRow, connection["id"])
|
||||||
|
credential_row = await session.get(ChannelCredentialRow, connection["id"])
|
||||||
|
assert connection_row is not None
|
||||||
|
assert connection_row.status == "revoked"
|
||||||
|
assert credential_row is None
|
||||||
|
assert (
|
||||||
|
await repo.find_connection_by_external_identity(
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
)
|
||||||
|
is None
|
||||||
|
)
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_disconnect_connection_is_owner_scoped(self, repo):
|
||||||
|
connection = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
)
|
||||||
|
|
||||||
|
disconnected = await repo.disconnect_connection(
|
||||||
|
connection_id=connection["id"],
|
||||||
|
owner_user_id="bob",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert disconnected is False
|
||||||
|
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_consume_oauth_state_deletes_expired_states(self, repo):
|
||||||
|
now = datetime.now(UTC)
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
state="expired-state",
|
||||||
|
expires_at=now - timedelta(minutes=1),
|
||||||
|
)
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
state="active-state",
|
||||||
|
expires_at=now + timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
|
||||||
|
consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
|
||||||
|
|
||||||
|
assert consumed is None
|
||||||
|
async with repo.session_factory() as session:
|
||||||
|
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
|
||||||
|
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
|
||||||
@@ -0,0 +1,707 @@
|
|||||||
|
"""Router tests for browser-connectable IM channels."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from _router_auth_helpers import make_authed_test_app
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
from app.gateway.routers import channel_connections
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
|
||||||
|
def _user() -> User:
|
||||||
|
return User(
|
||||||
|
id=UUID("11111111-2222-3333-4444-555555555555"),
|
||||||
|
email="alice@example.com",
|
||||||
|
password_hash="x",
|
||||||
|
system_role="user",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_repo(tmp_path):
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||||
|
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||||
|
|
||||||
|
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'router.db'}", sqlite_dir=str(tmp_path))
|
||||||
|
return ChannelConnectionRepository(get_session_factory())
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app(
|
||||||
|
config: ChannelConnectionsConfig,
|
||||||
|
repo,
|
||||||
|
channels_config: dict | None = None,
|
||||||
|
*,
|
||||||
|
runtime_config_store: ChannelRuntimeConfigStore | None = None,
|
||||||
|
set_channels_config_state: bool = True,
|
||||||
|
):
|
||||||
|
app = make_authed_test_app(user_factory=_user)
|
||||||
|
app.state.channel_connections_config = config
|
||||||
|
app.state.channel_connection_repo = repo
|
||||||
|
if set_channels_config_state:
|
||||||
|
app.state.channels_config = channels_config or {}
|
||||||
|
if runtime_config_store is not None:
|
||||||
|
app.state.channel_runtime_config_store = runtime_config_store
|
||||||
|
app.include_router(channel_connections.router)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def _enabled_connections_config() -> ChannelConnectionsConfig:
|
||||||
|
return ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
"discord": {"enabled": True},
|
||||||
|
"feishu": {"enabled": True},
|
||||||
|
"dingtalk": {"enabled": True},
|
||||||
|
"wechat": {"enabled": True},
|
||||||
|
"wecom": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _channels_config() -> dict:
|
||||||
|
return {
|
||||||
|
"telegram": {"enabled": True, "bot_token": "telegram-token"},
|
||||||
|
"slack": {"enabled": True, "bot_token": "xoxb-operator", "app_token": "xapp-operator"},
|
||||||
|
"discord": {"enabled": True, "bot_token": "discord-bot"},
|
||||||
|
"feishu": {"enabled": True, "app_id": "feishu-app", "app_secret": "feishu-secret"},
|
||||||
|
"dingtalk": {"enabled": True, "client_id": "dingtalk-client", "client_secret": "dingtalk-secret"},
|
||||||
|
"wechat": {"enabled": True, "bot_token": "wechat-token"},
|
||||||
|
"wecom": {"enabled": True, "bot_id": "wecom-bot", "bot_secret": "wecom-secret"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_providers_only_returns_enabled_channels_and_setup_fields(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
"discord": {"enabled": False},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
app = _make_app(config, repo, {})
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["enabled"] is True
|
||||||
|
assert [provider["provider"] for provider in body["providers"]] == ["slack"]
|
||||||
|
assert body["providers"][0]["configured"] is False
|
||||||
|
assert body["providers"][0]["connectable"] is False
|
||||||
|
assert body["providers"][0]["credential_fields"] == [
|
||||||
|
{
|
||||||
|
"name": "bot_token",
|
||||||
|
"label": "Bot token",
|
||||||
|
"type": "password",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "app_token",
|
||||||
|
"label": "App token",
|
||||||
|
"type": "password",
|
||||||
|
"required": True,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_providers_uses_existing_channels_config(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["enabled"] is True
|
||||||
|
by_provider = {item["provider"]: item for item in body["providers"]}
|
||||||
|
assert set(by_provider) == {"telegram", "slack", "discord", "feishu", "dingtalk", "wechat", "wecom"}
|
||||||
|
assert by_provider["telegram"]["configured"] is True
|
||||||
|
assert by_provider["telegram"]["auth_mode"] == "deep_link"
|
||||||
|
assert by_provider["telegram"]["credential_values"] == {
|
||||||
|
"bot_token": "********",
|
||||||
|
"bot_username": "deerflow_bot",
|
||||||
|
}
|
||||||
|
assert by_provider["slack"]["configured"] is True
|
||||||
|
assert by_provider["slack"]["auth_mode"] == "binding_code"
|
||||||
|
assert by_provider["slack"]["connection_status"] == "connected"
|
||||||
|
assert by_provider["slack"]["credential_values"] == {
|
||||||
|
"bot_token": "********",
|
||||||
|
"app_token": "********",
|
||||||
|
}
|
||||||
|
assert by_provider["discord"]["configured"] is True
|
||||||
|
assert by_provider["discord"]["auth_mode"] == "binding_code"
|
||||||
|
assert by_provider["discord"]["credential_values"] == {"bot_token": "********"}
|
||||||
|
assert by_provider["feishu"]["configured"] is True
|
||||||
|
assert by_provider["feishu"]["auth_mode"] == "binding_code"
|
||||||
|
assert by_provider["feishu"]["connection_status"] == "connected"
|
||||||
|
assert by_provider["feishu"]["credential_values"] == {
|
||||||
|
"app_id": "feishu-app",
|
||||||
|
"app_secret": "********",
|
||||||
|
}
|
||||||
|
assert by_provider["dingtalk"]["configured"] is True
|
||||||
|
assert by_provider["dingtalk"]["auth_mode"] == "binding_code"
|
||||||
|
assert by_provider["dingtalk"]["credential_values"] == {
|
||||||
|
"client_id": "dingtalk-client",
|
||||||
|
"client_secret": "********",
|
||||||
|
}
|
||||||
|
assert by_provider["wechat"]["configured"] is True
|
||||||
|
assert by_provider["wechat"]["auth_mode"] == "binding_code"
|
||||||
|
assert by_provider["wechat"]["credential_values"] == {"bot_token": "********"}
|
||||||
|
assert by_provider["wecom"]["configured"] is True
|
||||||
|
assert by_provider["wecom"]["auth_mode"] == "binding_code"
|
||||||
|
assert by_provider["wecom"]["credential_values"] == {
|
||||||
|
"bot_id": "wecom-bot",
|
||||||
|
"bot_secret": "********",
|
||||||
|
}
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_providers_reports_unconfigured_when_runtime_channel_is_missing(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, {"telegram": {"enabled": True, "bot_token": "telegram-token"}})
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||||
|
assert by_provider["telegram"]["configured"] is True
|
||||||
|
assert by_provider["slack"]["configured"] is False
|
||||||
|
assert by_provider["slack"]["connectable"] is False
|
||||||
|
assert "Slack credentials" in by_provider["slack"]["unavailable_reason"]
|
||||||
|
assert by_provider["discord"]["configured"] is False
|
||||||
|
assert "Discord credentials" in by_provider["discord"]["unavailable_reason"]
|
||||||
|
assert by_provider["feishu"]["configured"] is False
|
||||||
|
assert "Feishu credentials" in by_provider["feishu"]["unavailable_reason"]
|
||||||
|
assert by_provider["dingtalk"]["configured"] is False
|
||||||
|
assert "DingTalk credentials" in by_provider["dingtalk"]["unavailable_reason"]
|
||||||
|
assert by_provider["wechat"]["configured"] is False
|
||||||
|
assert "WeChat credentials" in by_provider["wechat"]["unavailable_reason"]
|
||||||
|
assert by_provider["wecom"]["configured"] is False
|
||||||
|
assert "WeCom credentials" in by_provider["wecom"]["unavailable_reason"]
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_providers_reports_configured_channel_not_running(tmp_path, monkeypatch):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
service = SimpleNamespace(
|
||||||
|
get_status=lambda: {
|
||||||
|
"service_running": True,
|
||||||
|
"channels": {
|
||||||
|
"feishu": {
|
||||||
|
"enabled": True,
|
||||||
|
"running": False,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||||
|
assert by_provider["feishu"]["configured"] is True
|
||||||
|
assert by_provider["feishu"]["connectable"] is False
|
||||||
|
assert by_provider["feishu"]["connection_status"] == "not_connected"
|
||||||
|
assert "configured but is not running" in by_provider["feishu"]["unavailable_reason"]
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
|
||||||
|
async def seed_connections():
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-old",
|
||||||
|
workspace_id="T-old",
|
||||||
|
status="revoked",
|
||||||
|
)
|
||||||
|
await anyio.sleep(0.01)
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-new",
|
||||||
|
workspace_id="T-new",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
|
||||||
|
anyio.run(seed_connections)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||||
|
assert by_provider["slack"]["connection_status"] == "connected"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_connections_returns_current_user_connections_only(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
|
||||||
|
async def seed_connections():
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
external_account_name="Alice",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id="other-user",
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="99",
|
||||||
|
external_account_name="Bob",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
|
||||||
|
anyio.run(seed_connections)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channels/connections")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert len(body["connections"]) == 1
|
||||||
|
assert body["connections"][0]["provider"] == "telegram"
|
||||||
|
assert body["connections"][0]["external_account_id"] == "42"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_telegram_returns_deep_link_and_persists_state(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/api/channels/telegram/connect")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["provider"] == "telegram"
|
||||||
|
assert body["mode"] == "deep_link"
|
||||||
|
assert body["url"].startswith("https://t.me/deerflow_bot?start=")
|
||||||
|
assert body["code"]
|
||||||
|
assert "/start" in body["instruction"]
|
||||||
|
|
||||||
|
async def count_states():
|
||||||
|
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="telegram")
|
||||||
|
|
||||||
|
assert anyio.run(count_states) == 1
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/api/channels/slack/connect")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["provider"] == "slack"
|
||||||
|
assert body["mode"] == "binding_code"
|
||||||
|
assert body["url"] is None
|
||||||
|
assert len(body["code"]) >= 22
|
||||||
|
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot."
|
||||||
|
|
||||||
|
async def count_states():
|
||||||
|
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack")
|
||||||
|
|
||||||
|
assert anyio.run(count_states) == 1
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_discord_returns_binding_command_and_persists_state(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/api/channels/discord/connect")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["provider"] == "discord"
|
||||||
|
assert body["mode"] == "binding_code"
|
||||||
|
assert body["url"] is None
|
||||||
|
assert body["code"]
|
||||||
|
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Discord bot."
|
||||||
|
|
||||||
|
async def count_states():
|
||||||
|
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="discord")
|
||||||
|
|
||||||
|
assert anyio.run(count_states) == 1
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_existing_binding_code_channels_return_command_and_persist_state(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
providers = ["feishu", "dingtalk", "wechat", "wecom"]
|
||||||
|
with TestClient(app) as client:
|
||||||
|
responses = {provider: client.post(f"/api/channels/{provider}/connect") for provider in providers}
|
||||||
|
|
||||||
|
for provider, response in responses.items():
|
||||||
|
expected_display_name = {
|
||||||
|
"feishu": "Feishu",
|
||||||
|
"dingtalk": "DingTalk",
|
||||||
|
"wechat": "WeChat",
|
||||||
|
"wecom": "WeCom",
|
||||||
|
}[provider]
|
||||||
|
assert response.status_code == 200
|
||||||
|
body = response.json()
|
||||||
|
assert body["provider"] == provider
|
||||||
|
assert body["mode"] == "binding_code"
|
||||||
|
assert body["url"] is None
|
||||||
|
assert len(body["code"]) >= 22
|
||||||
|
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow {expected_display_name} bot."
|
||||||
|
|
||||||
|
async def count_states(provider=provider):
|
||||||
|
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider=provider)
|
||||||
|
|
||||||
|
assert anyio.run(count_states) == 1
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_connect_unconfigured_runtime_channel_returns_400(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, {})
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.post("/api/channels/slack/connect")
|
||||||
|
|
||||||
|
assert response.status_code == 400
|
||||||
|
assert "Slack credentials" in response.json()["detail"]
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_provider_runtime_credentials_enables_connect_without_file_edits(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
app = _make_app(config, repo, {})
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/slack/runtime-config",
|
||||||
|
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||||
|
)
|
||||||
|
connect_response = client.post("/api/channels/slack/connect")
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
configured = configure_response.json()
|
||||||
|
assert configured["provider"] == "slack"
|
||||||
|
assert configured["configured"] is True
|
||||||
|
assert configured["connectable"] is True
|
||||||
|
assert configured["connection_status"] == "connected"
|
||||||
|
assert app.state.channels_config["slack"] == {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
}
|
||||||
|
assert connect_response.status_code == 200
|
||||||
|
assert connect_response.json()["provider"] == "slack"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
runtime_config_path = tmp_path / "channels" / "runtime-config.json"
|
||||||
|
first_app = _make_app(
|
||||||
|
config,
|
||||||
|
repo,
|
||||||
|
{},
|
||||||
|
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
|
||||||
|
)
|
||||||
|
|
||||||
|
with TestClient(first_app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/slack/runtime-config",
|
||||||
|
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
|
||||||
|
restarted_app = _make_app(
|
||||||
|
config,
|
||||||
|
repo,
|
||||||
|
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
|
||||||
|
set_channels_config_state=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
with TestClient(restarted_app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||||
|
assert by_provider["slack"]["configured"] is True
|
||||||
|
assert by_provider["slack"]["connectable"] is True
|
||||||
|
assert by_provider["slack"]["connection_status"] == "connected"
|
||||||
|
assert restarted_app.state.channels_config["slack"] == {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
}
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_configure_provider_runtime_credentials_preserves_masked_secrets(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"feishu": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||||
|
app = _make_app(
|
||||||
|
config,
|
||||||
|
repo,
|
||||||
|
{
|
||||||
|
"feishu": {
|
||||||
|
"enabled": True,
|
||||||
|
"app_id": "old-app-id",
|
||||||
|
"app_secret": "old-secret",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
runtime_config_store=runtime_config_store,
|
||||||
|
)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/feishu/runtime-config",
|
||||||
|
json={
|
||||||
|
"values": {
|
||||||
|
"app_id": "new-app-id",
|
||||||
|
"app_secret": "********",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
providers_response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
assert app.state.channels_config["feishu"] == {
|
||||||
|
"enabled": True,
|
||||||
|
"app_id": "new-app-id",
|
||||||
|
"app_secret": "old-secret",
|
||||||
|
}
|
||||||
|
assert runtime_config_store.get_provider_config("feishu") == {
|
||||||
|
"enabled": True,
|
||||||
|
"app_id": "new-app-id",
|
||||||
|
"app_secret": "old-secret",
|
||||||
|
}
|
||||||
|
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||||
|
assert by_provider["feishu"]["credential_values"] == {
|
||||||
|
"app_id": "new-app-id",
|
||||||
|
"app_secret": "********",
|
||||||
|
}
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||||
|
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/slack/runtime-config",
|
||||||
|
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||||
|
)
|
||||||
|
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||||
|
providers_response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
assert disconnect_response.status_code == 200
|
||||||
|
disconnected = disconnect_response.json()
|
||||||
|
assert disconnected["provider"] == "slack"
|
||||||
|
assert disconnected["configured"] is False
|
||||||
|
assert disconnected["connectable"] is False
|
||||||
|
assert disconnected["connection_status"] == "not_connected"
|
||||||
|
assert runtime_config_store.get_provider_config("slack") is None
|
||||||
|
|
||||||
|
assert providers_response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||||
|
assert by_provider["slack"]["connection_status"] == "not_connected"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
|
||||||
|
async def seed_connection():
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U123",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
|
||||||
|
anyio.run(seed_connection)
|
||||||
|
config = ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||||
|
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
configure_response = client.post(
|
||||||
|
"/api/channels/slack/runtime-config",
|
||||||
|
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||||
|
)
|
||||||
|
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||||
|
|
||||||
|
assert configure_response.status_code == 200
|
||||||
|
assert disconnect_response.status_code == 200
|
||||||
|
|
||||||
|
async def get_connection_status():
|
||||||
|
return (await repo.list_connections(str(_user().id)))[0]["status"]
|
||||||
|
|
||||||
|
assert anyio.run(get_connection_status) == "revoked"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
|
||||||
|
async def seed_connection():
|
||||||
|
connection = await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
return connection["id"]
|
||||||
|
|
||||||
|
connection_id = anyio.run(seed_connection)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.delete(f"/api/channels/connections/{connection_id}")
|
||||||
|
|
||||||
|
assert response.status_code == 204
|
||||||
|
|
||||||
|
async def get_connection_status():
|
||||||
|
return (await repo.list_connections(str(_user().id)))[0]["status"]
|
||||||
|
|
||||||
|
assert anyio.run(get_connection_status) == "revoked"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disconnect_connection_is_current_user_scoped(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
|
||||||
|
async def seed_connection():
|
||||||
|
connection = await repo.upsert_connection(
|
||||||
|
owner_user_id="other-user",
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
return connection["id"]
|
||||||
|
|
||||||
|
connection_id = anyio.run(seed_connection)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.delete(f"/api/channels/connections/{connection_id}")
|
||||||
|
|
||||||
|
assert response.status_code == 404
|
||||||
|
|
||||||
|
async def get_connection_status():
|
||||||
|
return (await repo.list_connections("other-user"))[0]["status"]
|
||||||
|
|
||||||
|
assert anyio.run(get_connection_status) == "connected"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
@@ -487,6 +487,7 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
|
|||||||
|
|
||||||
# threads.create() returns a Thread-like dict
|
# threads.create() returns a Thread-like dict
|
||||||
mock_client.threads.create = AsyncMock(return_value={"thread_id": thread_id})
|
mock_client.threads.create = AsyncMock(return_value={"thread_id": thread_id})
|
||||||
|
mock_client.threads.update = AsyncMock(return_value={"thread_id": thread_id})
|
||||||
|
|
||||||
# threads.get() returns thread info (succeeds by default)
|
# threads.get() returns thread info (succeeds by default)
|
||||||
mock_client.threads.get = AsyncMock(return_value={"thread_id": thread_id})
|
mock_client.threads.get = AsyncMock(return_value={"thread_id": thread_id})
|
||||||
@@ -504,6 +505,17 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
|
|||||||
return mock_client
|
return mock_client
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_channel_connection_repo(tmp_path: Path):
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||||
|
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||||
|
|
||||||
|
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'channel-connections.db'}", sqlite_dir=str(tmp_path))
|
||||||
|
return ChannelConnectionRepository(
|
||||||
|
get_session_factory(),
|
||||||
|
cipher=ChannelCredentialCipher.from_key("test-channel-key"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _make_stream_part(event: str, data):
|
def _make_stream_part(event: str, data):
|
||||||
return SimpleNamespace(event=event, data=data)
|
return SimpleNamespace(event=event, data=data)
|
||||||
|
|
||||||
@@ -656,16 +668,34 @@ class TestChannelManager:
|
|||||||
|
|
||||||
await manager.start()
|
await manager.start()
|
||||||
|
|
||||||
inbound = InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi")
|
inbound = InboundMessage(
|
||||||
|
channel_name="test",
|
||||||
|
chat_id="chat1",
|
||||||
|
user_id="user1",
|
||||||
|
text="hi",
|
||||||
|
topic_id="topic1",
|
||||||
|
thread_ts="msg1",
|
||||||
|
connection_id="conn1",
|
||||||
|
)
|
||||||
await bus.publish_inbound(inbound)
|
await bus.publish_inbound(inbound)
|
||||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||||
await manager.stop()
|
await manager.stop()
|
||||||
|
|
||||||
# Thread should be created through Gateway
|
# Thread should be created through Gateway
|
||||||
mock_client.threads.create.assert_called_once()
|
mock_client.threads.create.assert_called_once()
|
||||||
|
assert mock_client.threads.create.call_args.kwargs["metadata"] == {
|
||||||
|
"channel_source": {
|
||||||
|
"type": "im_channel",
|
||||||
|
"provider": "test",
|
||||||
|
"chat_id": "chat1",
|
||||||
|
"topic_id": "topic1",
|
||||||
|
"thread_ts": "msg1",
|
||||||
|
"connection_id": "conn1",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
# Thread ID should be stored
|
# Thread ID should be stored
|
||||||
thread_id = store.get_thread_id("test", "chat1")
|
thread_id = store.get_thread_id("test", "chat1", topic_id="topic1")
|
||||||
assert thread_id == "test-thread-123"
|
assert thread_id == "test-thread-123"
|
||||||
|
|
||||||
# runs.wait should be called with the thread_id
|
# runs.wait should be called with the thread_id
|
||||||
@@ -883,10 +913,12 @@ class TestChannelManager:
|
|||||||
|
|
||||||
_run(go())
|
_run(go())
|
||||||
|
|
||||||
def test_clarification_follow_up_preserves_history(self):
|
def test_clarification_follow_up_preserves_history(self, monkeypatch):
|
||||||
"""Conversation should continue after ask_clarification instead of resetting history."""
|
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||||
from app.channels.manager import ChannelManager
|
from app.channels.manager import ChannelManager
|
||||||
|
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
|
||||||
async def go():
|
async def go():
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
@@ -1954,10 +1986,12 @@ class TestChannelManager:
|
|||||||
|
|
||||||
_run(go())
|
_run(go())
|
||||||
|
|
||||||
def test_same_topic_reuses_thread(self):
|
def test_same_topic_reuses_thread(self, monkeypatch):
|
||||||
"""Messages with the same topic_id should reuse the same DeerFlow thread."""
|
"""Messages with the same topic_id should reuse the same DeerFlow thread."""
|
||||||
from app.channels.manager import ChannelManager
|
from app.channels.manager import ChannelManager
|
||||||
|
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
|
||||||
async def go():
|
async def go():
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
@@ -1990,6 +2024,17 @@ class TestChannelManager:
|
|||||||
|
|
||||||
# threads.create should be called only ONCE (second message reuses the thread)
|
# threads.create should be called only ONCE (second message reuses the thread)
|
||||||
mock_client.threads.create.assert_called_once()
|
mock_client.threads.create.assert_called_once()
|
||||||
|
mock_client.threads.update.assert_called_once_with(
|
||||||
|
"topic-thread-1",
|
||||||
|
metadata={
|
||||||
|
"channel_source": {
|
||||||
|
"type": "im_channel",
|
||||||
|
"provider": "test",
|
||||||
|
"chat_id": "chat1",
|
||||||
|
"topic_id": "topic-root-123",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
# Both runs.wait calls should use the same thread_id
|
# Both runs.wait calls should use the same thread_id
|
||||||
assert mock_client.runs.wait.call_count == 2
|
assert mock_client.runs.wait.call_count == 2
|
||||||
@@ -2325,8 +2370,9 @@ class TestResolveRunParamsUserId:
|
|||||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
return ChannelManager(bus=bus, store=store)
|
return ChannelManager(bus=bus, store=store)
|
||||||
|
|
||||||
def test_safe_user_id_is_passed_through(self):
|
def test_safe_user_id_is_passed_through(self, monkeypatch):
|
||||||
manager = self._manager()
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
|
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
|
||||||
|
|
||||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
@@ -2334,10 +2380,78 @@ class TestResolveRunParamsUserId:
|
|||||||
assert run_context["user_id"] == "123456"
|
assert run_context["user_id"] == "123456"
|
||||||
assert run_context["channel_user_id"] == "123456"
|
assert run_context["channel_user_id"] == "123456"
|
||||||
|
|
||||||
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
|
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self, monkeypatch):
|
||||||
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel_name="slack",
|
||||||
|
chat_id="C123",
|
||||||
|
user_id="U-platform",
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
connection_id="connection-1",
|
||||||
|
text="hi",
|
||||||
|
)
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == "deerflow-user-1"
|
||||||
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
|
||||||
|
def test_auth_disabled_user_id_is_used_for_unbound_channel_messages(self, monkeypatch):
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||||
|
|
||||||
|
manager = self._manager()
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||||
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
|
||||||
|
from app.channels.manager import _owner_headers
|
||||||
|
|
||||||
|
headers = _owner_headers(msg)
|
||||||
|
assert headers is not None
|
||||||
|
assert headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
|
def test_auth_disabled_user_id_overrides_bound_owner_for_local_visibility(self, monkeypatch):
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
|
manager = self._manager()
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
msg = InboundMessage(
|
||||||
|
channel_name="slack",
|
||||||
|
chat_id="C123",
|
||||||
|
user_id="U-platform",
|
||||||
|
owner_user_id="real-user-from-old-binding",
|
||||||
|
text="hi",
|
||||||
|
)
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||||
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
|
||||||
|
def test_unbound_channel_messages_keep_platform_user_id_when_auth_is_enabled(self, monkeypatch):
|
||||||
|
from app.channels.manager import _owner_headers
|
||||||
|
|
||||||
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == "U-platform"
|
||||||
|
assert run_context["channel_user_id"] == "U-platform"
|
||||||
|
assert _owner_headers(msg) is None
|
||||||
|
|
||||||
|
def test_unsafe_user_id_is_normalized_but_raw_preserved(self, monkeypatch):
|
||||||
from deerflow.config.paths import make_safe_user_id
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
manager = self._manager()
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
raw = "user@example.com"
|
raw = "user@example.com"
|
||||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
||||||
|
|
||||||
@@ -2348,8 +2462,9 @@ class TestResolveRunParamsUserId:
|
|||||||
assert run_context["channel_user_id"] == raw
|
assert run_context["channel_user_id"] == raw
|
||||||
|
|
||||||
@pytest.mark.parametrize("raw_user_id", ["", None])
|
@pytest.mark.parametrize("raw_user_id", ["", None])
|
||||||
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
|
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id, monkeypatch):
|
||||||
manager = self._manager()
|
manager = self._manager()
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
|
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
|
||||||
|
|
||||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
@@ -2358,6 +2473,93 @@ class TestResolveRunParamsUserId:
|
|||||||
assert "channel_user_id" not in run_context
|
assert "channel_user_id" not in run_context
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannelManagerConnectionRouting:
|
||||||
|
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path, monkeypatch):
|
||||||
|
from app.channels.manager import ChannelManager
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||||
|
from deerflow.persistence.engine import close_engine
|
||||||
|
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = await _make_channel_connection_repo(tmp_path)
|
||||||
|
alice = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-alice",
|
||||||
|
workspace_id="T1",
|
||||||
|
)
|
||||||
|
bob = await repo.upsert_connection(
|
||||||
|
owner_user_id="bob",
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-bob",
|
||||||
|
workspace_id="T1",
|
||||||
|
)
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
store = ChannelStore(path=tmp_path / "legacy-store.json")
|
||||||
|
manager = ChannelManager(bus=bus, store=store, connection_repo=repo)
|
||||||
|
mock_client = _make_mock_langgraph_client()
|
||||||
|
mock_client.threads.create = AsyncMock(
|
||||||
|
side_effect=[
|
||||||
|
{"thread_id": "thread-alice"},
|
||||||
|
{"thread_id": "thread-bob"},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
manager._client = mock_client
|
||||||
|
|
||||||
|
await manager._handle_chat(
|
||||||
|
InboundMessage(
|
||||||
|
channel_name="slack",
|
||||||
|
chat_id="C-shared",
|
||||||
|
user_id="U-alice",
|
||||||
|
owner_user_id="alice",
|
||||||
|
connection_id=alice["id"],
|
||||||
|
text="hello",
|
||||||
|
thread_ts="1710000000.000100",
|
||||||
|
topic_id="1710000000.000100",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await manager._handle_chat(
|
||||||
|
InboundMessage(
|
||||||
|
channel_name="slack",
|
||||||
|
chat_id="C-shared",
|
||||||
|
user_id="U-bob",
|
||||||
|
owner_user_id="bob",
|
||||||
|
connection_id=bob["id"],
|
||||||
|
text="hello",
|
||||||
|
thread_ts="1710000000.000100",
|
||||||
|
topic_id="1710000000.000100",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
|
||||||
|
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
|
||||||
|
assert store.list_entries() == []
|
||||||
|
|
||||||
|
first_context = mock_client.runs.wait.call_args_list[0].kwargs["context"]
|
||||||
|
second_context = mock_client.runs.wait.call_args_list[1].kwargs["context"]
|
||||||
|
assert first_context["user_id"] == "alice"
|
||||||
|
assert first_context["channel_user_id"] == "U-alice"
|
||||||
|
assert second_context["user_id"] == "bob"
|
||||||
|
assert second_context["channel_user_id"] == "U-bob"
|
||||||
|
|
||||||
|
first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"]
|
||||||
|
second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"]
|
||||||
|
assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||||
|
assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||||
|
|
||||||
|
first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"]
|
||||||
|
second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"]
|
||||||
|
assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||||
|
assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||||
|
|
||||||
|
try:
|
||||||
|
_run(go())
|
||||||
|
finally:
|
||||||
|
_run(close_engine())
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# ChannelService tests
|
# ChannelService tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -3175,6 +3377,144 @@ class TestChannelService:
|
|||||||
|
|
||||||
assert service._config == {"telegram": {"enabled": False}}
|
assert service._config == {"telegram": {"enabled": False}}
|
||||||
|
|
||||||
|
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(
|
||||||
|
self,
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||||
|
monkeypatch.setattr(paths_module, "_paths", None)
|
||||||
|
app_config = SimpleNamespace(
|
||||||
|
model_extra={},
|
||||||
|
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
"discord": {"enabled": True},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
service = ChannelService.from_app_config(app_config)
|
||||||
|
|
||||||
|
assert service._config == {}
|
||||||
|
|
||||||
|
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(
|
||||||
|
self,
|
||||||
|
monkeypatch,
|
||||||
|
tmp_path,
|
||||||
|
):
|
||||||
|
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||||
|
monkeypatch.setattr(paths_module, "_paths", None)
|
||||||
|
ChannelRuntimeConfigStore().set_provider_config(
|
||||||
|
"slack",
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app_config = SimpleNamespace(
|
||||||
|
model_extra={
|
||||||
|
"channels": {
|
||||||
|
"telegram": {"enabled": True, "bot_token": "telegram-token"},
|
||||||
|
"slack": {"enabled": True, "bot_token": "xoxb", "app_token": "xapp"},
|
||||||
|
"discord": {"enabled": True, "bot_token": "discord-bot-token"},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
"discord": {"enabled": True},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
service = ChannelService.from_app_config(app_config)
|
||||||
|
|
||||||
|
assert service._config["telegram"]["bot_token"] == "telegram-token"
|
||||||
|
assert service._config["slack"]["app_token"] == "xapp"
|
||||||
|
assert service._config["discord"]["bot_token"] == "discord-bot-token"
|
||||||
|
|
||||||
|
def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path):
|
||||||
|
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
from deerflow.config import paths as paths_module
|
||||||
|
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||||
|
monkeypatch.setattr(paths_module, "_paths", None)
|
||||||
|
ChannelRuntimeConfigStore().set_provider_config(
|
||||||
|
"slack",
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
app_config = SimpleNamespace(
|
||||||
|
model_extra={},
|
||||||
|
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||||
|
{
|
||||||
|
"enabled": True,
|
||||||
|
"slack": {"enabled": True},
|
||||||
|
}
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
service = ChannelService.from_app_config(app_config)
|
||||||
|
|
||||||
|
assert service._config["slack"] == {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_connection_repo_is_forwarded_to_manager(self):
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
|
||||||
|
repo = object()
|
||||||
|
service = ChannelService(channels_config={}, connection_repo=repo)
|
||||||
|
|
||||||
|
assert service.manager._connection_repo is repo
|
||||||
|
|
||||||
|
def test_remove_channel_stops_running_channel_and_forgets_config(self):
|
||||||
|
from app.channels.service import ChannelService
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
service = ChannelService(
|
||||||
|
channels_config={
|
||||||
|
"slack": {
|
||||||
|
"enabled": True,
|
||||||
|
"bot_token": "xoxb-ui",
|
||||||
|
"app_token": "xapp-ui",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
channel = AsyncMock()
|
||||||
|
service._channels["slack"] = channel
|
||||||
|
service._running = True
|
||||||
|
|
||||||
|
assert await service.remove_channel("slack") is True
|
||||||
|
|
||||||
|
channel.stop.assert_awaited_once()
|
||||||
|
assert "slack" not in service._channels
|
||||||
|
assert "slack" not in service._config
|
||||||
|
|
||||||
|
_run(go())
|
||||||
|
|
||||||
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
|
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
|
||||||
"""Warning is emitted when a channel has string credentials but enabled=false."""
|
"""Warning is emitted when a channel has string credentials but enabled=false."""
|
||||||
import logging
|
import logging
|
||||||
|
|||||||
@@ -2472,6 +2472,7 @@ class TestGatewayConformance:
|
|||||||
mem_cfg.fact_confidence_threshold = 0.7
|
mem_cfg.fact_confidence_threshold = 0.7
|
||||||
mem_cfg.injection_enabled = True
|
mem_cfg.injection_enabled = True
|
||||||
mem_cfg.max_injection_tokens = 2000
|
mem_cfg.max_injection_tokens = 2000
|
||||||
|
mem_cfg.token_counting = "tiktoken"
|
||||||
|
|
||||||
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
|
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
|
||||||
result = client.get_memory_config()
|
result = client.get_memory_config()
|
||||||
@@ -2479,6 +2480,7 @@ class TestGatewayConformance:
|
|||||||
parsed = MemoryConfigResponse(**result)
|
parsed = MemoryConfigResponse(**result)
|
||||||
assert parsed.enabled is True
|
assert parsed.enabled is True
|
||||||
assert parsed.max_facts == 100
|
assert parsed.max_facts == 100
|
||||||
|
assert parsed.token_counting == "tiktoken"
|
||||||
|
|
||||||
def test_get_memory_status(self, client):
|
def test_get_memory_status(self, client):
|
||||||
mem_cfg = MagicMock()
|
mem_cfg = MagicMock()
|
||||||
@@ -2489,6 +2491,7 @@ class TestGatewayConformance:
|
|||||||
mem_cfg.fact_confidence_threshold = 0.7
|
mem_cfg.fact_confidence_threshold = 0.7
|
||||||
mem_cfg.injection_enabled = True
|
mem_cfg.injection_enabled = True
|
||||||
mem_cfg.max_injection_tokens = 2000
|
mem_cfg.max_injection_tokens = 2000
|
||||||
|
mem_cfg.token_counting = "tiktoken"
|
||||||
|
|
||||||
memory_data = {
|
memory_data = {
|
||||||
"version": "1.0",
|
"version": "1.0",
|
||||||
@@ -2514,6 +2517,7 @@ class TestGatewayConformance:
|
|||||||
|
|
||||||
parsed = MemoryStatusResponse(**result)
|
parsed = MemoryStatusResponse(**result)
|
||||||
assert parsed.config.enabled is True
|
assert parsed.config.enabled is True
|
||||||
|
assert parsed.config.token_counting == "tiktoken"
|
||||||
assert parsed.data.version == "1.0"
|
assert parsed.data.version == "1.0"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
"""Regression test for the Docker Compose default Gateway worker count.
|
||||||
|
|
||||||
|
The Gateway holds run state (RunManager and the stream bridge) in process, so
|
||||||
|
the default deployment must run a single Uvicorn worker. Running more than one
|
||||||
|
worker without a shared cross-worker stream bridge breaks run cancellation, SSE
|
||||||
|
reconnects, request de-duplication, and IM channels (nginx has no sticky
|
||||||
|
sessions, so requests scatter across workers that each keep their own run
|
||||||
|
state). This test pins the safe default so it cannot silently regress to a
|
||||||
|
multi-worker default, while still allowing operators to override it once a
|
||||||
|
shared stream bridge exists.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||||
|
COMPOSE_PATH = REPO_ROOT / "docker" / "docker-compose.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
def _gateway_command() -> str:
|
||||||
|
"""Return the gateway service command as a single string."""
|
||||||
|
compose = yaml.safe_load(COMPOSE_PATH.read_text(encoding="utf-8"))
|
||||||
|
command = compose["services"]["gateway"]["command"]
|
||||||
|
# ``command`` may load as a scalar string or a list depending on YAML style.
|
||||||
|
if isinstance(command, list):
|
||||||
|
command = " ".join(str(part) for part in command)
|
||||||
|
return command
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_defaults_to_single_worker():
|
||||||
|
"""With GATEWAY_WORKERS unset, the worker count must default to 1."""
|
||||||
|
command = _gateway_command()
|
||||||
|
match = re.search(r"GATEWAY_WORKERS:-(\d+)", command)
|
||||||
|
assert match is not None, f"gateway command must set a GATEWAY_WORKERS default; got: {command}"
|
||||||
|
assert match.group(1) == "1", f"default Gateway worker count must be 1, got {match.group(1)}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_gateway_worker_count_remains_overridable():
|
||||||
|
"""The worker count must stay configurable, not hard-coded to 1."""
|
||||||
|
command = _gateway_command()
|
||||||
|
assert "${GATEWAY_WORKERS:-1}" in command, f"worker count must use ${{GATEWAY_WORKERS:-1}} so operators can override it; got: {command}"
|
||||||
@@ -233,3 +233,15 @@ def test_non_auth_mutation_rejects_mismatched_double_submit_token():
|
|||||||
|
|
||||||
assert response.status_code == 403
|
assert response.status_code == 403
|
||||||
assert response.json()["detail"] == "CSRF token mismatch."
|
assert response.json()["detail"] == "CSRF token mismatch."
|
||||||
|
|
||||||
|
|
||||||
|
def test_channel_posts_require_double_submit_csrf():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/channels/slack/connect",
|
||||||
|
headers={"Origin": "https://deerflow.example"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
|
||||||
|
|||||||
@@ -203,6 +203,79 @@ class TestLoadAgentConfig:
|
|||||||
assert cfg.name == "legacy-agent"
|
assert cfg.name == "legacy-agent"
|
||||||
|
|
||||||
|
|
||||||
|
# ===========================================================================
|
||||||
|
# 3b. resolve_agent_dir — memory-only directory fallback (#3390)
|
||||||
|
# ===========================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveAgentDirMemoryOnlyFallback:
|
||||||
|
"""Regression tests for #3390.
|
||||||
|
|
||||||
|
When memory is enabled, the first conversation creates a user-isolated
|
||||||
|
agent directory containing only ``memory.json`` (no ``config.yaml``).
|
||||||
|
On the next turn ``resolve_agent_dir`` must fall through to the legacy
|
||||||
|
shared layout instead of returning the incomplete user directory.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_user_dir_with_only_memory_falls_back_to_legacy(self, tmp_path):
|
||||||
|
"""User dir has memory.json but no config.yaml → use legacy dir."""
|
||||||
|
from deerflow.config.agents_config import resolve_agent_dir
|
||||||
|
|
||||||
|
# Legacy agent with full config
|
||||||
|
legacy_dir = tmp_path / "agents" / "my-agent"
|
||||||
|
legacy_dir.mkdir(parents=True)
|
||||||
|
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
|
||||||
|
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
|
||||||
|
|
||||||
|
# User dir created by memory write — no config.yaml
|
||||||
|
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||||
|
user_dir.mkdir(parents=True)
|
||||||
|
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||||
|
result = resolve_agent_dir("my-agent", user_id="u1")
|
||||||
|
|
||||||
|
assert result == legacy_dir
|
||||||
|
|
||||||
|
def test_user_dir_with_config_takes_priority(self, tmp_path):
|
||||||
|
"""User dir with config.yaml should still win over legacy."""
|
||||||
|
from deerflow.config.agents_config import resolve_agent_dir
|
||||||
|
|
||||||
|
# Legacy
|
||||||
|
legacy_dir = tmp_path / "agents" / "my-agent"
|
||||||
|
legacy_dir.mkdir(parents=True)
|
||||||
|
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
|
||||||
|
|
||||||
|
# User dir with full config (migrated)
|
||||||
|
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||||
|
user_dir.mkdir(parents=True)
|
||||||
|
(user_dir / "config.yaml").write_text("name: my-agent\nmodel: gpt-4\n", encoding="utf-8")
|
||||||
|
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||||
|
result = resolve_agent_dir("my-agent", user_id="u1")
|
||||||
|
|
||||||
|
assert result == user_dir
|
||||||
|
|
||||||
|
def test_load_config_falls_back_when_user_dir_is_memory_only(self, tmp_path):
|
||||||
|
"""End-to-end: load_agent_config works when user dir only has memory.json."""
|
||||||
|
config_dict = {"name": "my-agent", "description": "Legacy agent", "model": "deepseek-v3"}
|
||||||
|
_write_agent(tmp_path, "my-agent", config_dict)
|
||||||
|
|
||||||
|
# Simulate memory write creating user dir without config
|
||||||
|
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||||
|
user_dir.mkdir(parents=True)
|
||||||
|
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||||
|
|
||||||
|
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||||
|
from deerflow.config.agents_config import load_agent_config
|
||||||
|
|
||||||
|
cfg = load_agent_config("my-agent", user_id="u1")
|
||||||
|
|
||||||
|
assert cfg.name == "my-agent"
|
||||||
|
assert cfg.model == "deepseek-v3"
|
||||||
|
|
||||||
|
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
# 4. load_agent_soul
|
# 4. load_agent_soul
|
||||||
# ===========================================================================
|
# ===========================================================================
|
||||||
|
|||||||
@@ -0,0 +1,88 @@
|
|||||||
|
"""Discord connection routing tests."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.channels.discord import DiscordChannel
|
||||||
|
from app.channels.message_bus import InboundMessage, MessageBus
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def repo(tmp_path):
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||||
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
|
|
||||||
|
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'discord.db'}", sqlite_dir=str(tmp_path))
|
||||||
|
try:
|
||||||
|
yield ChannelConnectionRepository(
|
||||||
|
get_session_factory(),
|
||||||
|
cipher=ChannelCredentialCipher.from_key("discord-secret"),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await close_engine()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_discord_inbound_attaches_owner_identity_from_user_level_connection(repo):
|
||||||
|
connection = await repo.upsert_connection(
|
||||||
|
owner_user_id="alice",
|
||||||
|
provider="discord",
|
||||||
|
external_account_id="987",
|
||||||
|
external_account_name="Alice",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
channel = DiscordChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"bot_token": "discord-bot", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
inbound = InboundMessage(
|
||||||
|
channel_name="discord",
|
||||||
|
chat_id="C123",
|
||||||
|
user_id="987",
|
||||||
|
text="hello",
|
||||||
|
)
|
||||||
|
|
||||||
|
attached = await channel._attach_connection_identity(inbound, guild_id="G123")
|
||||||
|
|
||||||
|
assert attached.connection_id == connection["id"]
|
||||||
|
assert attached.owner_user_id == "alice"
|
||||||
|
assert attached.workspace_id is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_discord_connect_command_binds_gateway_identity(repo):
|
||||||
|
state = "discord-bind-code"
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="discord",
|
||||||
|
state=state,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
channel = DiscordChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"bot_token": "discord-bot", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
message = MagicMock()
|
||||||
|
message.author.id = 987
|
||||||
|
message.author.display_name = "Alice"
|
||||||
|
message.guild.id = 123
|
||||||
|
message.guild.name = "Deer Guild"
|
||||||
|
message.channel.id = 456
|
||||||
|
message.channel.send = AsyncMock()
|
||||||
|
|
||||||
|
handled = await channel._bind_connection_from_connect_code(message, state)
|
||||||
|
|
||||||
|
connections = await repo.list_connections("deerflow-user-1")
|
||||||
|
assert handled is True
|
||||||
|
assert len(connections) == 1
|
||||||
|
assert connections[0]["provider"] == "discord"
|
||||||
|
assert connections[0]["external_account_id"] == "987"
|
||||||
|
assert connections[0]["external_account_name"] == "Alice"
|
||||||
|
assert connections[0]["workspace_id"] == "123"
|
||||||
|
assert connections[0]["workspace_name"] == "Deer Guild"
|
||||||
|
assert connections[0]["metadata"]["channel_id"] == "456"
|
||||||
|
message.channel.send.assert_awaited_once()
|
||||||
@@ -73,6 +73,31 @@ def test_feishu_on_message_plain_text():
|
|||||||
assert mock_make_inbound.call_args[1]["text"] == "Hello world"
|
assert mock_make_inbound.call_args[1]["text"] == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_feishu_is_not_running_when_ws_thread_exits():
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
|
||||||
|
channel._running = True
|
||||||
|
channel._thread = MagicMock()
|
||||||
|
channel._thread.is_alive.return_value = False
|
||||||
|
|
||||||
|
assert channel.is_running is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_feishu_event_handler_ignores_non_content_message_events():
|
||||||
|
import lark_oapi as lark
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
|
||||||
|
|
||||||
|
event_handler = channel._build_event_handler(lark)
|
||||||
|
|
||||||
|
assert "p2.im.message.receive_v1" in event_handler._processorMap
|
||||||
|
assert "p2.im.message.message_read_v1" in event_handler._processorMap
|
||||||
|
assert "p2.im.message.reaction.created_v1" in event_handler._processorMap
|
||||||
|
assert "p2.im.message.reaction.deleted_v1" in event_handler._processorMap
|
||||||
|
assert "p2.im.message.recalled_v1" in event_handler._processorMap
|
||||||
|
|
||||||
|
|
||||||
def test_feishu_on_message_rich_text():
|
def test_feishu_on_message_rich_text():
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
config = {"app_id": "test", "app_secret": "test"}
|
config = {"app_id": "test", "app_secret": "test"}
|
||||||
|
|||||||
@@ -474,6 +474,83 @@ def test_inject_authenticated_user_context_skips_internal_role():
|
|||||||
assert config["context"]["user_id"] == "channel-user-7"
|
assert config["context"]["user_id"] == "channel-user-7"
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_run_uses_internal_owner_header_for_persistence():
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
|
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||||
|
from app.gateway.services import start_run
|
||||||
|
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||||
|
from deerflow.runtime import RunManager
|
||||||
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
|
async def _scenario():
|
||||||
|
run_store = MemoryRunStore()
|
||||||
|
thread_store = MemoryThreadMetaStore(InMemoryStore())
|
||||||
|
await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True})
|
||||||
|
run_manager = RunManager(store=run_store)
|
||||||
|
state = SimpleNamespace(
|
||||||
|
stream_bridge=SimpleNamespace(),
|
||||||
|
run_manager=run_manager,
|
||||||
|
checkpointer=InMemorySaver(),
|
||||||
|
store=InMemoryStore(),
|
||||||
|
run_event_store=SimpleNamespace(),
|
||||||
|
run_events_config=None,
|
||||||
|
thread_store=thread_store,
|
||||||
|
)
|
||||||
|
request = SimpleNamespace(
|
||||||
|
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||||
|
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||||
|
app=SimpleNamespace(state=state),
|
||||||
|
)
|
||||||
|
body = SimpleNamespace(
|
||||||
|
assistant_id="lead_agent",
|
||||||
|
input={"messages": [{"role": "human", "content": "hi"}]},
|
||||||
|
metadata={},
|
||||||
|
config=None,
|
||||||
|
context=None,
|
||||||
|
on_disconnect="cancel",
|
||||||
|
multitask_strategy="reject",
|
||||||
|
stream_mode=None,
|
||||||
|
stream_subgraphs=False,
|
||||||
|
interrupt_before=None,
|
||||||
|
interrupt_after=None,
|
||||||
|
)
|
||||||
|
task_context: dict[str, str] = {}
|
||||||
|
|
||||||
|
async def fake_run_agent(*args, **kwargs):
|
||||||
|
task_context["user_id"] = get_effective_user_id()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.gateway.services.resolve_agent_factory", return_value=object()),
|
||||||
|
patch("app.gateway.services.run_agent", side_effect=fake_run_agent),
|
||||||
|
):
|
||||||
|
record = await start_run(body, "channel-thread", request)
|
||||||
|
await record.task
|
||||||
|
|
||||||
|
owner_run = await run_store.get(record.run_id, user_id="owner-1")
|
||||||
|
default_run = await run_store.get(record.run_id, user_id="default")
|
||||||
|
owner_thread = await thread_store.get("channel-thread", user_id="owner-1")
|
||||||
|
default_thread = await thread_store.get("channel-thread", user_id="default")
|
||||||
|
return owner_run, default_run, owner_thread, default_thread, task_context
|
||||||
|
|
||||||
|
owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario())
|
||||||
|
|
||||||
|
assert owner_run is not None
|
||||||
|
assert owner_run["user_id"] == "owner-1"
|
||||||
|
assert default_run is None
|
||||||
|
assert owner_thread is not None
|
||||||
|
assert owner_thread["user_id"] == "owner-1"
|
||||||
|
assert owner_thread["metadata"] == {"legacy": True}
|
||||||
|
assert default_thread is None
|
||||||
|
assert task_context["user_id"] == "owner-1"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch):
|
|||||||
assert reloaded.is_valid_internal_auth_token(token) is True
|
assert reloaded.is_valid_internal_auth_token(token) is True
|
||||||
finally:
|
finally:
|
||||||
importlib.reload(reloaded)
|
importlib.reload(reloaded)
|
||||||
|
|
||||||
|
|
||||||
|
def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch):
|
||||||
|
import app.gateway.internal_auth as internal_auth
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
|
||||||
|
reloaded = importlib.reload(internal_auth)
|
||||||
|
try:
|
||||||
|
headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1")
|
||||||
|
|
||||||
|
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
|
||||||
|
assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1"
|
||||||
|
finally:
|
||||||
|
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
|
||||||
|
importlib.reload(reloaded)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from langgraph_sdk import Auth
|
|||||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||||
from app.gateway.auth.models import User
|
from app.gateway.auth.models import User
|
||||||
|
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||||
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||||
@@ -59,6 +60,14 @@ def test_no_cookie_raises_401():
|
|||||||
assert "Not authenticated" in str(exc.value.detail)
|
assert "Not authenticated" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_disabled_skips_csrf_and_authenticates_e2e_user(monkeypatch):
|
||||||
|
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||||
|
|
||||||
|
identity = asyncio.run(authenticate(_req(method="POST")))
|
||||||
|
|
||||||
|
assert identity == AUTH_DISABLED_USER_ID
|
||||||
|
|
||||||
|
|
||||||
def test_invalid_jwt_raises_401():
|
def test_invalid_jwt_raises_401():
|
||||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
asyncio.run(authenticate(_req({"access_token": "garbage"})))
|
asyncio.run(authenticate(_req({"access_token": "garbage"})))
|
||||||
|
|||||||
@@ -192,7 +192,7 @@ def test_build_acp_section_uses_explicit_app_config_without_global_config(monkey
|
|||||||
|
|
||||||
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
|
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
|
||||||
explicit_config = SimpleNamespace(
|
explicit_config = SimpleNamespace(
|
||||||
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234),
|
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234, token_counting="tiktoken"),
|
||||||
)
|
)
|
||||||
captured: dict[str, object] = {}
|
captured: dict[str, object] = {}
|
||||||
|
|
||||||
@@ -204,9 +204,10 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
|
|||||||
captured["user_id"] = user_id
|
captured["user_id"] = user_id
|
||||||
return {"facts": []}
|
return {"facts": []}
|
||||||
|
|
||||||
def fake_format_memory_for_injection(memory_data, *, max_tokens):
|
def fake_format_memory_for_injection(memory_data, *, max_tokens, use_tiktoken=True):
|
||||||
captured["memory_data"] = memory_data
|
captured["memory_data"] = memory_data
|
||||||
captured["max_tokens"] = max_tokens
|
captured["max_tokens"] = max_tokens
|
||||||
|
captured["use_tiktoken"] = use_tiktoken
|
||||||
return "remember this"
|
return "remember this"
|
||||||
|
|
||||||
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
|
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
|
||||||
@@ -223,6 +224,7 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
|
|||||||
"user_id": "user-1",
|
"user_id": "user-1",
|
||||||
"memory_data": {"facts": []},
|
"memory_data": {"facts": []},
|
||||||
"max_tokens": 1234,
|
"max_tokens": 1234,
|
||||||
|
"use_tiktoken": True,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -612,6 +612,54 @@ class TestLocalSandboxProviderMounts:
|
|||||||
|
|
||||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||||
|
|
||||||
|
def test_setup_path_mappings_logs_actionable_error_for_missing_host_path(self, tmp_path, caplog):
|
||||||
|
"""Regression for #3244.
|
||||||
|
|
||||||
|
When ``sandbox.mounts[].host_path`` is absent from the gateway process's
|
||||||
|
filesystem (the typical symptom in Docker production mode: host_path is a
|
||||||
|
host machine path that is not bind-mounted into the gateway container),
|
||||||
|
the mount is still skipped — but the failure must be a hard-to-miss ERROR
|
||||||
|
log with explicit, actionable guidance about Docker bind mounts, not the
|
||||||
|
old DEBUG/WARNING that buried the silent failure.
|
||||||
|
"""
|
||||||
|
skills_dir = tmp_path / "skills"
|
||||||
|
skills_dir.mkdir()
|
||||||
|
missing_host_path = tmp_path / "does-not-exist"
|
||||||
|
|
||||||
|
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||||
|
|
||||||
|
sandbox_config = SandboxConfig(
|
||||||
|
use="deerflow.sandbox.local:LocalSandboxProvider",
|
||||||
|
mounts=[
|
||||||
|
VolumeMountConfig(host_path=str(missing_host_path), container_path="/mnt/knowledge", read_only=True),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
config = SimpleNamespace(
|
||||||
|
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir, use="deerflow.skills.storage.local_skill_storage:LocalSkillStorage"),
|
||||||
|
sandbox=sandbox_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
with caplog.at_level("ERROR", logger="deerflow.sandbox.local.local_sandbox_provider"):
|
||||||
|
with patch("deerflow.config.get_app_config", return_value=config):
|
||||||
|
provider = LocalSandboxProvider()
|
||||||
|
|
||||||
|
# Silent-skip behaviour is preserved (no breaking change for existing deployments).
|
||||||
|
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||||
|
|
||||||
|
# The failure must be observable at ERROR level and reference the offending paths.
|
||||||
|
error_records = [r for r in caplog.records if r.levelname == "ERROR"]
|
||||||
|
assert error_records, "expected an ERROR log when host_path is missing"
|
||||||
|
message = "\n".join(r.getMessage() for r in error_records)
|
||||||
|
assert str(missing_host_path) in message
|
||||||
|
assert "/mnt/knowledge" in message
|
||||||
|
|
||||||
|
# And it must include actionable Docker guidance so users don't lose hours
|
||||||
|
# to a silent empty-mount failure in production.
|
||||||
|
lowered = message.lower()
|
||||||
|
assert "docker" in lowered
|
||||||
|
assert "gateway" in lowered
|
||||||
|
assert "docker-compose" in lowered
|
||||||
|
|
||||||
def test_write_file_resolves_container_paths_in_content(self, tmp_path):
|
def test_write_file_resolves_container_paths_in_content(self, tmp_path):
|
||||||
"""write_file should replace container paths in file content with local paths."""
|
"""write_file should replace container paths in file content with local paths."""
|
||||||
data_dir = tmp_path / "data"
|
data_dir = tmp_path / "data"
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ def test_format_memory_sorts_facts_by_confidence_desc() -> None:
|
|||||||
|
|
||||||
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
|
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
|
||||||
# Make token counting deterministic for this test by counting characters.
|
# Make token counting deterministic for this test by counting characters.
|
||||||
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
|
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base", *, use_tiktoken=True: len(text))
|
||||||
|
|
||||||
memory_data = {
|
memory_data = {
|
||||||
"user": {},
|
"user": {},
|
||||||
|
|||||||
@@ -179,15 +179,16 @@ class TestLifecycleCallbacks:
|
|||||||
assert "run.end" in types
|
assert "run.end" in types
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_nested_chain_no_run_start(self, journal_setup):
|
async def test_nested_chain_no_run_lifecycle_events(self, journal_setup):
|
||||||
"""Nested chains (parent_run_id set) should NOT produce run.start."""
|
"""Nested chains (parent_run_id set) should NOT produce root run lifecycle events."""
|
||||||
j, store = journal_setup
|
j, store = journal_setup
|
||||||
parent_id = uuid4()
|
parent_id = uuid4()
|
||||||
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
|
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
|
||||||
j.on_chain_end({}, run_id=uuid4())
|
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
|
||||||
await j.flush()
|
await j.flush()
|
||||||
events = await store.list_events("t1", "r1")
|
events = await store.list_events("t1", "r1")
|
||||||
assert not any(e["event_type"] == "run.start" for e in events)
|
assert not any(e["event_type"] == "run.start" for e in events)
|
||||||
|
assert not any(e["event_type"] == "run.end" for e in events)
|
||||||
|
|
||||||
|
|
||||||
class TestToolCallbacks:
|
class TestToolCallbacks:
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ Run from repo root:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from wizard import ui as wizard_ui
|
||||||
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
|
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
|
||||||
|
from wizard.steps import channels as channels_step
|
||||||
from wizard.steps import llm as llm_step
|
from wizard.steps import llm as llm_step
|
||||||
from wizard.steps import search as search_step
|
from wizard.steps import search as search_step
|
||||||
from wizard.writer import (
|
from wizard.writer import (
|
||||||
@@ -327,6 +329,44 @@ class TestBuildMinimalConfig:
|
|||||||
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
|
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
|
||||||
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
|
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
|
||||||
|
|
||||||
|
def test_can_enable_selected_channel_connections(self):
|
||||||
|
content = build_minimal_config(
|
||||||
|
provider_use="langchain_openai:ChatOpenAI",
|
||||||
|
model_name="gpt-4o",
|
||||||
|
display_name="OpenAI",
|
||||||
|
api_key_field="api_key",
|
||||||
|
env_var="OPENAI_API_KEY",
|
||||||
|
channel_connection_providers=["feishu", "slack"],
|
||||||
|
)
|
||||||
|
|
||||||
|
data = yaml.safe_load(content)
|
||||||
|
channel_connections = data["channel_connections"]
|
||||||
|
|
||||||
|
assert channel_connections["enabled"] is True
|
||||||
|
assert channel_connections["feishu"]["enabled"] is True
|
||||||
|
assert channel_connections["slack"]["enabled"] is True
|
||||||
|
assert channel_connections["telegram"]["enabled"] is False
|
||||||
|
assert channel_connections["discord"]["enabled"] is False
|
||||||
|
assert channel_connections["dingtalk"]["enabled"] is False
|
||||||
|
assert channel_connections["wechat"]["enabled"] is False
|
||||||
|
assert channel_connections["wecom"]["enabled"] is False
|
||||||
|
|
||||||
|
def test_channel_connections_disabled_when_no_channels_selected(self):
|
||||||
|
content = build_minimal_config(
|
||||||
|
provider_use="langchain_openai:ChatOpenAI",
|
||||||
|
model_name="gpt-4o",
|
||||||
|
display_name="OpenAI",
|
||||||
|
api_key_field="api_key",
|
||||||
|
env_var="OPENAI_API_KEY",
|
||||||
|
channel_connection_providers=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
data = yaml.safe_load(content)
|
||||||
|
channel_connections = data["channel_connections"]
|
||||||
|
|
||||||
|
assert channel_connections["enabled"] is False
|
||||||
|
assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled")
|
||||||
|
|
||||||
|
|
||||||
class TestLLMStep:
|
class TestLLMStep:
|
||||||
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
|
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
|
||||||
@@ -384,6 +424,41 @@ class TestLLMStep:
|
|||||||
assert result.base_url == "https://gateway.example/v1"
|
assert result.base_url == "https://gateway.example/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannelsStep:
|
||||||
|
def test_returns_selected_channel_keys(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6])
|
||||||
|
|
||||||
|
result = channels_step.run_channels_step()
|
||||||
|
|
||||||
|
assert result.enabled_providers == ["telegram", "feishu", "wecom"]
|
||||||
|
|
||||||
|
def test_empty_selection_disables_channel_connections(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [])
|
||||||
|
|
||||||
|
result = channels_step.run_channels_step()
|
||||||
|
|
||||||
|
assert result.enabled_providers == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestWizardUi:
|
||||||
|
def test_multi_choice_blank_requires_input_without_default(self, monkeypatch):
|
||||||
|
answers = iter(["", "2"])
|
||||||
|
monkeypatch.setattr("builtins.input", lambda _prompt: next(answers))
|
||||||
|
|
||||||
|
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1]
|
||||||
|
|
||||||
|
def test_multi_choice_blank_accepts_empty_default(self, monkeypatch):
|
||||||
|
monkeypatch.setattr("builtins.input", lambda _prompt: "")
|
||||||
|
|
||||||
|
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == []
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# writer.py — env file helpers
|
# writer.py — env file helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -0,0 +1,154 @@
|
|||||||
|
"""Slack connection tests for user-owned channel bindings."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from types import ModuleType
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
from app.channels.message_bus import MessageBus, OutboundMessage
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_repo(tmp_path):
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||||
|
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||||
|
|
||||||
|
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'slack.db'}", sqlite_dir=str(tmp_path))
|
||||||
|
return ChannelConnectionRepository(
|
||||||
|
get_session_factory(),
|
||||||
|
cipher=ChannelCredentialCipher.from_key("slack-secret"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_connect_command_binds_socket_mode_identity(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.slack import SlackChannel
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = await _make_repo(tmp_path)
|
||||||
|
state = "slack-bind-code"
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="slack",
|
||||||
|
state=state,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
channel = SlackChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"bot_token": "xoxb-operator", "app_token": "xapp-operator", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
channel._web_client = MagicMock()
|
||||||
|
|
||||||
|
handled = await channel._bind_connection_from_connect_code(
|
||||||
|
event={
|
||||||
|
"user": "U123",
|
||||||
|
"channel": "C123",
|
||||||
|
"ts": "1710000000.000100",
|
||||||
|
},
|
||||||
|
team_id="T123",
|
||||||
|
code=state,
|
||||||
|
)
|
||||||
|
|
||||||
|
connections = await repo.list_connections("deerflow-user-1")
|
||||||
|
assert handled is True
|
||||||
|
assert len(connections) == 1
|
||||||
|
assert connections[0]["provider"] == "slack"
|
||||||
|
assert connections[0]["external_account_id"] == "U123"
|
||||||
|
assert connections[0]["workspace_id"] == "T123"
|
||||||
|
assert connections[0]["metadata"]["channel_id"] == "C123"
|
||||||
|
channel._web_client.chat_postMessage.assert_called_once()
|
||||||
|
await repo.close()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.slack import SlackChannel
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
repo = AsyncMock()
|
||||||
|
repo.get_credentials.return_value = {"access_token": "xoxb-connection-token"}
|
||||||
|
web_client = MagicMock()
|
||||||
|
web_client_factory = MagicMock(return_value=web_client)
|
||||||
|
channel = SlackChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={
|
||||||
|
"connection_repo": repo,
|
||||||
|
"web_client_factory": web_client_factory,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
msg = OutboundMessage(
|
||||||
|
channel_name="slack",
|
||||||
|
chat_id="C123",
|
||||||
|
thread_id="thread-1",
|
||||||
|
text="hello",
|
||||||
|
connection_id="connection-1",
|
||||||
|
)
|
||||||
|
await channel.send(msg)
|
||||||
|
|
||||||
|
repo.get_credentials.assert_awaited_once_with("connection-1")
|
||||||
|
web_client_factory.assert_called_once_with(token="xoxb-connection-token")
|
||||||
|
web_client.chat_postMessage.assert_called_once()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.slack import SlackChannel
|
||||||
|
|
||||||
|
class FakeWebClient:
|
||||||
|
def __init__(self, token: str) -> None:
|
||||||
|
self.token = token
|
||||||
|
self.messages: list[dict] = []
|
||||||
|
|
||||||
|
def auth_test(self):
|
||||||
|
return {"user_id": "B-http"}
|
||||||
|
|
||||||
|
def chat_postMessage(self, **kwargs):
|
||||||
|
self.messages.append(kwargs)
|
||||||
|
|
||||||
|
slack_sdk = ModuleType("slack_sdk")
|
||||||
|
slack_sdk.WebClient = FakeWebClient
|
||||||
|
socket_mode = ModuleType("slack_sdk.socket_mode")
|
||||||
|
socket_mode.SocketModeClient = object
|
||||||
|
response = ModuleType("slack_sdk.socket_mode.response")
|
||||||
|
response.SocketModeResponse = object
|
||||||
|
monkeypatch.setitem(sys.modules, "slack_sdk", slack_sdk)
|
||||||
|
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode", socket_mode)
|
||||||
|
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode.response", response)
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
channel = SlackChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={
|
||||||
|
"bot_token": "xoxb-operator",
|
||||||
|
"event_delivery": "http",
|
||||||
|
"connection_repo": MagicMock(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
assert channel._running is True
|
||||||
|
assert channel._web_client is not None
|
||||||
|
assert channel._web_client.token == "xoxb-operator"
|
||||||
|
assert channel._bot_user_id == "B-http"
|
||||||
|
|
||||||
|
channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
|
||||||
|
|
||||||
|
assert channel._web_client.messages == [
|
||||||
|
{
|
||||||
|
"channel": "C123",
|
||||||
|
"text": "Slack connected to DeerFlow.",
|
||||||
|
"thread_ts": "1710000000.000100",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
await channel.stop()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
@@ -0,0 +1,173 @@
|
|||||||
|
"""Cross-user isolation for the stateless ``POST /api/runs/stream`` and ``/wait`` endpoints.
|
||||||
|
|
||||||
|
These endpoints receive ``thread_id`` in the request body, so the
|
||||||
|
``@require_permission(owner_check=True)`` decorator — which reads the
|
||||||
|
``thread_id`` *path* parameter — cannot protect them. The owner check
|
||||||
|
lives inside ``services.start_run()`` instead; this suite pins it at the
|
||||||
|
HTTP layer so the gap cannot silently reopen.
|
||||||
|
|
||||||
|
Strategy
|
||||||
|
--------
|
||||||
|
``app.state.run_manager.create_or_reject`` raises ``ConflictError``, so a
|
||||||
|
request that *passes* the owner check deterministically short-circuits
|
||||||
|
with 409 before any agent code runs. The two outcomes:
|
||||||
|
|
||||||
|
- 404 + ``create_or_reject`` never awaited -> blocked by the owner check
|
||||||
|
- 409 + ``create_or_reject`` awaited -> passed the owner check
|
||||||
|
|
||||||
|
The thread store is a real ``MemoryThreadMetaStore`` (not a mock) so the
|
||||||
|
``check_access`` semantics under test — missing row allows, ``user_id``
|
||||||
|
NULL allows, foreign owner denies — are exercised through real code.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from _router_auth_helpers import make_authed_test_app
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
|
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
from app.gateway.routers import runs
|
||||||
|
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||||
|
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||||
|
from deerflow.runtime import ConflictError
|
||||||
|
|
||||||
|
USER_A = User(email="owner-a@example.com", password_hash="x", system_role="user", id=uuid4())
|
||||||
|
USER_B = User(email="intruder-b@example.com", password_hash="x", system_role="user", id=uuid4())
|
||||||
|
INTERNAL_USER = SimpleNamespace(id="default", system_role="internal")
|
||||||
|
|
||||||
|
THREAD_A = "thread-owned-by-a"
|
||||||
|
THREAD_SHARED = "thread-shared-null-owner"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _stub_app_config():
|
||||||
|
"""Inject a minimal AppConfig so the allowed path (which builds a
|
||||||
|
RunContext via ``get_config()``) never reads config.yaml from disk."""
|
||||||
|
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
|
||||||
|
yield
|
||||||
|
reset_app_config()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_thread_store() -> MemoryThreadMetaStore:
|
||||||
|
store = MemoryThreadMetaStore(InMemoryStore())
|
||||||
|
|
||||||
|
async def _seed():
|
||||||
|
await store.create(THREAD_A, user_id=str(USER_A.id))
|
||||||
|
await store.create(THREAD_SHARED, user_id=None)
|
||||||
|
|
||||||
|
asyncio.run(_seed())
|
||||||
|
return store
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _client(user):
|
||||||
|
"""Yield a ``TestClient`` authenticated as ``user`` plus the stubbed
|
||||||
|
``create_or_reject`` mock, closing the client (and its anyio portal /
|
||||||
|
background threads) on exit.
|
||||||
|
|
||||||
|
``create_or_reject`` raises ``ConflictError`` so a request that passes the
|
||||||
|
owner check short-circuits to 409 before any agent code runs.
|
||||||
|
"""
|
||||||
|
app = make_authed_test_app(user_factory=lambda: user)
|
||||||
|
app.include_router(runs.router)
|
||||||
|
app.state.thread_store = _make_thread_store()
|
||||||
|
app.state.stream_bridge = MagicMock()
|
||||||
|
app.state.checkpointer = MagicMock()
|
||||||
|
app.state.store = MagicMock()
|
||||||
|
app.state.run_events_config = None
|
||||||
|
app.state.run_event_store = MagicMock()
|
||||||
|
run_manager = MagicMock()
|
||||||
|
run_manager.create_or_reject = AsyncMock(side_effect=ConflictError("sentinel: owner check passed"))
|
||||||
|
app.state.run_manager = run_manager
|
||||||
|
with TestClient(app) as client:
|
||||||
|
yield client, run_manager.create_or_reject
|
||||||
|
|
||||||
|
|
||||||
|
def _body(thread_id: str | None = None) -> dict:
|
||||||
|
if thread_id is None:
|
||||||
|
return {}
|
||||||
|
return {"config": {"configurable": {"thread_id": thread_id}}}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Denied: another user's thread
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_cross_user_returns_404():
|
||||||
|
"""User B cannot start a run on user A's thread via /api/runs/stream."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json()["detail"] == f"Thread {THREAD_A} not found"
|
||||||
|
create_or_reject.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait_cross_user_returns_404_without_channel_values():
|
||||||
|
"""User B cannot read user A's checkpoint state via /api/runs/wait."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/wait", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 404
|
||||||
|
assert response.json() == {"detail": f"Thread {THREAD_A} not found"}
|
||||||
|
create_or_reject.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Allowed: owner, fresh/untracked/shared threads, internal role
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_owner_passes_owner_check():
|
||||||
|
"""User A reaches run creation on their own thread (409 sentinel)."""
|
||||||
|
with _client(USER_A) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_wait_owner_passes_owner_check():
|
||||||
|
with _client(USER_A) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/wait", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_without_thread_id_passes_owner_check():
|
||||||
|
"""Stateless run with no thread_id auto-creates a thread — never blocked."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body())
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_untracked_thread_passes_owner_check():
|
||||||
|
"""A thread_id with no thread_meta row (untracked legacy) stays accessible."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body("never-created-thread"))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_shared_thread_passes_owner_check():
|
||||||
|
"""A thread_meta row with user_id NULL (shared / pre-auth data) stays accessible."""
|
||||||
|
with _client(USER_B) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_SHARED))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_internal_role_bypasses_owner_check():
|
||||||
|
"""IM channels run with the internal system role on behalf of platform
|
||||||
|
users whose threads they do not own — the owner check must not break them."""
|
||||||
|
with _client(INTERNAL_USER) as (client, create_or_reject):
|
||||||
|
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||||
|
assert response.status_code == 409
|
||||||
|
create_or_reject.assert_awaited()
|
||||||
@@ -0,0 +1,100 @@
|
|||||||
|
"""Tests for Telegram deep-link channel connections."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.channels.message_bus import MessageBus
|
||||||
|
from app.channels.telegram import TelegramChannel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def repo(tmp_path: Path):
|
||||||
|
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||||
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
|
|
||||||
|
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'telegram.db'}", sqlite_dir=str(tmp_path))
|
||||||
|
try:
|
||||||
|
yield ChannelConnectionRepository(
|
||||||
|
get_session_factory(),
|
||||||
|
cipher=ChannelCredentialCipher.from_key("telegram-secret"),
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
await close_engine()
|
||||||
|
|
||||||
|
|
||||||
|
def _telegram_update(*, text: str = "/start", user_id: int = 42, chat_id: int = 100, chat_type: str = "private"):
|
||||||
|
update = MagicMock()
|
||||||
|
update.effective_user.id = user_id
|
||||||
|
update.effective_user.username = "alice"
|
||||||
|
update.effective_user.full_name = "Alice Example"
|
||||||
|
update.effective_chat.id = chat_id
|
||||||
|
update.effective_chat.type = chat_type
|
||||||
|
update.message.text = text
|
||||||
|
update.message.message_id = 55
|
||||||
|
update.message.reply_to_message = None
|
||||||
|
update.message.reply_text = AsyncMock()
|
||||||
|
return update
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_start_with_deep_link_state_binds_telegram_chat(repo):
|
||||||
|
state = "telegram-bind-state"
|
||||||
|
await repo.create_oauth_state(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="telegram",
|
||||||
|
state=state,
|
||||||
|
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||||
|
)
|
||||||
|
channel = TelegramChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={"bot_token": "test-token", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
update = _telegram_update(text=f"/start {state}")
|
||||||
|
context = MagicMock()
|
||||||
|
context.args = [state]
|
||||||
|
|
||||||
|
await channel._cmd_start(update, context)
|
||||||
|
|
||||||
|
connections = await repo.list_connections("deerflow-user-1")
|
||||||
|
assert len(connections) == 1
|
||||||
|
assert connections[0]["provider"] == "telegram"
|
||||||
|
assert connections[0]["external_account_id"] == "42"
|
||||||
|
assert connections[0]["external_account_name"] == "Alice Example"
|
||||||
|
assert connections[0]["workspace_id"] == "100"
|
||||||
|
assert connections[0]["metadata"]["chat_type"] == "private"
|
||||||
|
update.message.reply_text.assert_awaited_once()
|
||||||
|
assert "connected" in update.message.reply_text.await_args.args[0].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_bound_telegram_message_publishes_connection_identity(repo):
|
||||||
|
connection = await repo.upsert_connection(
|
||||||
|
owner_user_id="deerflow-user-1",
|
||||||
|
provider="telegram",
|
||||||
|
external_account_id="42",
|
||||||
|
external_account_name="Alice Example",
|
||||||
|
workspace_id="100",
|
||||||
|
metadata={"chat_type": "private"},
|
||||||
|
)
|
||||||
|
bus = MessageBus()
|
||||||
|
channel = TelegramChannel(
|
||||||
|
bus=bus,
|
||||||
|
config={"bot_token": "test-token", "connection_repo": repo},
|
||||||
|
)
|
||||||
|
channel._main_loop = __import__("asyncio").get_event_loop()
|
||||||
|
channel._send_running_reply = AsyncMock()
|
||||||
|
|
||||||
|
await channel._on_text(_telegram_update(text="hello"), None)
|
||||||
|
inbound = await bus.get_inbound()
|
||||||
|
|
||||||
|
assert inbound.connection_id == connection["id"]
|
||||||
|
assert inbound.owner_user_id == "deerflow-user-1"
|
||||||
|
assert inbound.workspace_id == "100"
|
||||||
|
assert inbound.user_id == "42"
|
||||||
|
assert inbound.chat_id == "100"
|
||||||
|
assert inbound.text == "hello"
|
||||||
@@ -137,6 +137,19 @@ class TestThreadMetaRepository:
|
|||||||
async def test_update_metadata_nonexistent_is_noop(self, repo):
|
async def test_update_metadata_nonexistent_is_noop(self, repo):
|
||||||
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_owner_with_bypass_moves_row(self, repo):
|
||||||
|
await repo.create("t1", user_id="default", metadata={"source": "channel"})
|
||||||
|
await repo.update_owner("t1", "owner-1", user_id=None)
|
||||||
|
|
||||||
|
owner_row = await repo.get("t1", user_id="owner-1")
|
||||||
|
default_row = await repo.get("t1", user_id="default")
|
||||||
|
|
||||||
|
assert owner_row is not None
|
||||||
|
assert owner_row["user_id"] == "owner-1"
|
||||||
|
assert owner_row["metadata"] == {"source": "channel"}
|
||||||
|
assert default_row is None
|
||||||
|
|
||||||
# --- search with metadata filter (SQL push-down) ---
|
# --- search with metadata filter (SQL push-down) ---
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None:
|
|||||||
assert body["created_at"] == body["updated_at"]
|
assert body["created_at"] == body["updated_at"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_internal_owner_header_assigns_thread_to_owner() -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||||
|
|
||||||
|
store = InMemoryStore()
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
thread_store = MemoryThreadMetaStore(store)
|
||||||
|
request = SimpleNamespace(
|
||||||
|
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||||
|
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||||
|
app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _scenario():
|
||||||
|
response = await threads.create_thread(
|
||||||
|
threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}),
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
owner_row = await thread_store.get("channel-thread", user_id="owner-1")
|
||||||
|
internal_row = await thread_store.get("channel-thread", user_id="default")
|
||||||
|
return response, owner_row, internal_row
|
||||||
|
|
||||||
|
response, owner_row, internal_row = asyncio.run(_scenario())
|
||||||
|
|
||||||
|
assert response.thread_id == "channel-thread"
|
||||||
|
assert owner_row is not None
|
||||||
|
assert owner_row["user_id"] == "owner-1"
|
||||||
|
assert internal_row is None
|
||||||
|
|
||||||
|
|
||||||
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
|
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
|
||||||
"""A thread record written by older versions stores ``time.time()``
|
"""A thread record written by older versions stores ``time.time()``
|
||||||
floats. ``get_thread`` must transparently surface them as ISO so the
|
floats. ``get_thread`` must transparently surface them as ISO so the
|
||||||
|
|||||||
@@ -5,18 +5,22 @@ Verifies:
|
|||||||
- ``_count_tokens`` falls back to character estimation when tiktoken is
|
- ``_count_tokens`` falls back to character estimation when tiktoken is
|
||||||
unavailable or the encoding fails to load.
|
unavailable or the encoding fails to load.
|
||||||
- ``warm_tiktoken_cache`` populates the cache on success.
|
- ``warm_tiktoken_cache`` populates the cache on success.
|
||||||
|
- An in-flight tiktoken load prevents duplicate blocking downloads.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import threading
|
||||||
from unittest import mock
|
from unittest import mock
|
||||||
|
|
||||||
from deerflow.agents.memory.prompt import (
|
from deerflow.agents.memory.prompt import (
|
||||||
_count_tokens,
|
_count_tokens,
|
||||||
_get_tiktoken_encoding,
|
_get_tiktoken_encoding,
|
||||||
_tiktoken_encoding_cache,
|
_tiktoken_encoding_cache,
|
||||||
|
format_memory_for_injection,
|
||||||
warm_tiktoken_cache,
|
warm_tiktoken_cache,
|
||||||
)
|
)
|
||||||
|
from deerflow.config.memory_config import MemoryConfig
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _get_tiktoken_encoding
|
# _get_tiktoken_encoding
|
||||||
@@ -62,14 +66,103 @@ class TestGetTiktokenEncoding:
|
|||||||
assert enc is fake_enc
|
assert enc is fake_enc
|
||||||
tiktoken.get_encoding.assert_not_called()
|
tiktoken.get_encoding.assert_not_called()
|
||||||
|
|
||||||
def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch):
|
def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
|
||||||
|
"""A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
|
||||||
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
||||||
import tiktoken
|
import tiktoken
|
||||||
|
|
||||||
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed")))
|
get_encoding = mock.Mock(side_effect=OSError("download failed"))
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||||
|
|
||||||
result = _get_tiktoken_encoding("bogus_encoding")
|
result = _get_tiktoken_encoding("bogus_encoding")
|
||||||
assert result is None
|
assert result is None
|
||||||
assert "bogus_encoding" not in _tiktoken_encoding_cache
|
# The failure is remembered as a (None, timestamp) tuple.
|
||||||
|
assert "bogus_encoding" in _tiktoken_encoding_cache
|
||||||
|
cached = _tiktoken_encoding_cache["bogus_encoding"]
|
||||||
|
assert isinstance(cached, tuple)
|
||||||
|
assert cached[0] is None
|
||||||
|
|
||||||
|
# A second call must NOT re-attempt get_encoding (avoids re-blocking on
|
||||||
|
# the network download in restricted environments — see #3429).
|
||||||
|
result2 = _get_tiktoken_encoding("bogus_encoding")
|
||||||
|
assert result2 is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
# Cleanup module-level cache to avoid cross-test leakage.
|
||||||
|
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
||||||
|
|
||||||
|
def test_failure_self_heals_after_cooldown(self, monkeypatch):
|
||||||
|
"""After the retry cooldown expires, a transient failure is re-attempted and can recover."""
|
||||||
|
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
# First call fails, second call (after cooldown) succeeds.
|
||||||
|
get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||||
|
|
||||||
|
# Initial failure is cached.
|
||||||
|
assert _get_tiktoken_encoding("flaky_encoding") is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
# Within the cooldown window: no retry, immediate fallback.
|
||||||
|
assert _get_tiktoken_encoding("flaky_encoding") is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
# Simulate the cooldown having elapsed by ageing the cached timestamp.
|
||||||
|
from deerflow.agents.memory import prompt as prompt_module
|
||||||
|
|
||||||
|
_, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
|
||||||
|
_tiktoken_encoding_cache["flaky_encoding"] = (
|
||||||
|
None,
|
||||||
|
_failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Now the load is retried and recovers to accurate counting.
|
||||||
|
assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
|
||||||
|
assert get_encoding.call_count == 2
|
||||||
|
|
||||||
|
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
||||||
|
|
||||||
|
def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
|
||||||
|
"""Concurrent callers must not start duplicate blocking BPE downloads."""
|
||||||
|
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
||||||
|
import tiktoken
|
||||||
|
|
||||||
|
started = threading.Event()
|
||||||
|
release = threading.Event()
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
|
||||||
|
def slow_get_encoding(_name):
|
||||||
|
started.set()
|
||||||
|
assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
|
||||||
|
return fake_enc
|
||||||
|
|
||||||
|
get_encoding = mock.Mock(side_effect=slow_get_encoding)
|
||||||
|
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||||
|
|
||||||
|
result: dict[str, object | None] = {}
|
||||||
|
|
||||||
|
def load_encoding():
|
||||||
|
result["encoding"] = _get_tiktoken_encoding("slow_encoding")
|
||||||
|
|
||||||
|
thread = threading.Thread(target=load_encoding)
|
||||||
|
thread.start()
|
||||||
|
try:
|
||||||
|
assert started.wait(timeout=1), "slow get_encoding did not start"
|
||||||
|
|
||||||
|
# While the first call is still blocked, a second call should see
|
||||||
|
# the in-flight sentinel and fall back immediately instead of
|
||||||
|
# starting another potentially long network download.
|
||||||
|
assert _get_tiktoken_encoding("slow_encoding") is None
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
finally:
|
||||||
|
release.set()
|
||||||
|
thread.join(timeout=2)
|
||||||
|
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
||||||
|
|
||||||
|
assert result["encoding"] is fake_enc
|
||||||
|
assert get_encoding.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -115,6 +208,45 @@ class TestCountTokens:
|
|||||||
result = _count_tokens(text, encoding_name="test_enc")
|
result = _count_tokens(text, encoding_name="test_enc")
|
||||||
assert result == len(text) // 4
|
assert result == len(text) // 4
|
||||||
|
|
||||||
|
def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
|
||||||
|
"""use_tiktoken=False must never call tiktoken (guarantees no BPE download)."""
|
||||||
|
# Spy on both the encoding loader and tiktoken.get_encoding directly.
|
||||||
|
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
||||||
|
loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
|
||||||
|
|
||||||
|
text = "Hello, world! This is a network-free count."
|
||||||
|
result = _count_tokens(text, use_tiktoken=False)
|
||||||
|
assert result == len(text) // 4
|
||||||
|
get_encoding_spy.assert_not_called()
|
||||||
|
loader_spy.assert_not_called()
|
||||||
|
|
||||||
|
def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
|
||||||
|
"""CJK text should estimate more tokens than the plain len // 4 heuristic.
|
||||||
|
|
||||||
|
CJK characters are ~2 chars/token, so the char-based estimate must not
|
||||||
|
under-fill the budget the way ``len(text) // 4`` would.
|
||||||
|
"""
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
|
# "User prefers concise answers" rendered in CJK (Chinese) characters.
|
||||||
|
text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
|
||||||
|
result = _count_tokens(text)
|
||||||
|
# Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic).
|
||||||
|
assert result == len(text) // 2
|
||||||
|
assert result > len(text) // 4
|
||||||
|
|
||||||
|
def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
|
||||||
|
"""Mixed-language text should apply the CJK density only to CJK chars."""
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
|
# ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis".
|
||||||
|
text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
|
||||||
|
cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
|
||||||
|
|
||||||
|
result = _count_tokens(text)
|
||||||
|
|
||||||
|
assert result == (len(text) - cjk) // 4 + cjk // 2
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# warm_tiktoken_cache
|
# warm_tiktoken_cache
|
||||||
@@ -146,3 +278,69 @@ class TestWarmTiktokenCache:
|
|||||||
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
|
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
|
||||||
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||||
assert warm_tiktoken_cache() is False
|
assert warm_tiktoken_cache() is False
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# format_memory_for_injection token_counting strategy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestFormatMemoryForInjectionTokenCounting:
|
||||||
|
"""Verify the use_tiktoken flag is honoured end-to-end."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _sample_memory() -> dict:
|
||||||
|
return {
|
||||||
|
"facts": [
|
||||||
|
{"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9},
|
||||||
|
{"content": "User works in the finance domain.", "category": "context", "confidence": 0.8},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
|
||||||
|
"""With use_tiktoken=False, formatting must not call tiktoken at all."""
|
||||||
|
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
||||||
|
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
||||||
|
|
||||||
|
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False)
|
||||||
|
assert "User prefers concise answers." in result
|
||||||
|
get_encoding_spy.assert_not_called()
|
||||||
|
|
||||||
|
def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
|
||||||
|
"""With use_tiktoken=True (default), the cached encoding is used for counting."""
|
||||||
|
fake_enc = mock.Mock()
|
||||||
|
fake_enc.encode.side_effect = lambda text: list(range(len(text)))
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
|
||||||
|
mock.Mock(return_value=fake_enc),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True)
|
||||||
|
assert "User prefers concise answers." in result
|
||||||
|
assert fake_enc.encode.called
|
||||||
|
|
||||||
|
def test_empty_memory_returns_empty(self):
|
||||||
|
assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# MemoryConfig.token_counting
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestMemoryConfigTokenCounting:
|
||||||
|
"""Verify the new config field defaults and validation."""
|
||||||
|
|
||||||
|
def test_default_is_tiktoken(self):
|
||||||
|
"""Default must remain tiktoken so existing deployments are unaffected."""
|
||||||
|
assert MemoryConfig().token_counting == "tiktoken"
|
||||||
|
|
||||||
|
def test_accepts_char(self):
|
||||||
|
assert MemoryConfig(token_counting="char").token_counting == "char"
|
||||||
|
|
||||||
|
def test_rejects_invalid_value(self):
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
MemoryConfig(token_counting="invalid")
|
||||||
|
|||||||
Generated
+2
@@ -820,6 +820,7 @@ dependencies = [
|
|||||||
{ name = "agent-sandbox" },
|
{ name = "agent-sandbox" },
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "alembic" },
|
{ name = "alembic" },
|
||||||
|
{ name = "cryptography" },
|
||||||
{ name = "ddgs" },
|
{ name = "ddgs" },
|
||||||
{ name = "dotenv" },
|
{ name = "dotenv" },
|
||||||
{ name = "duckdb" },
|
{ name = "duckdb" },
|
||||||
@@ -871,6 +872,7 @@ requires-dist = [
|
|||||||
{ name = "aiosqlite", specifier = ">=0.19" },
|
{ name = "aiosqlite", specifier = ">=0.19" },
|
||||||
{ name = "alembic", specifier = ">=1.13" },
|
{ name = "alembic", specifier = ">=1.13" },
|
||||||
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
|
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
|
||||||
|
{ name = "cryptography", specifier = ">=43.0.0" },
|
||||||
{ name = "ddgs", specifier = ">=9.10.0" },
|
{ name = "ddgs", specifier = ">=9.10.0" },
|
||||||
{ name = "dotenv", specifier = ">=0.9.9" },
|
{ name = "dotenv", specifier = ">=0.9.9" },
|
||||||
{ name = "duckdb", specifier = ">=1.4.4" },
|
{ name = "duckdb", specifier = ">=1.4.4" },
|
||||||
|
|||||||
+54
-2
@@ -15,7 +15,7 @@
|
|||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Bump this number when the config schema changes.
|
# Bump this number when the config schema changes.
|
||||||
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
# Run `make config-upgrade` to merge new fields into your local config.yaml.
|
||||||
config_version: 11
|
config_version: 12
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Logging
|
# Logging
|
||||||
@@ -768,8 +768,12 @@ sandbox:
|
|||||||
allow_host_bash: false
|
allow_host_bash: false
|
||||||
# Optional: Mount additional host directories into the sandbox.
|
# Optional: Mount additional host directories into the sandbox.
|
||||||
# Each mount maps a host path to a virtual container path accessible by the agent.
|
# Each mount maps a host path to a virtual container path accessible by the agent.
|
||||||
|
# Note: with LocalSandboxProvider under `make up` (docker-compose), host_path is
|
||||||
|
# checked from inside the deer-flow-gateway container — you must also bind-mount
|
||||||
|
# the same directory into services.gateway.volumes in docker/docker-compose.yaml
|
||||||
|
# for this mount to take effect (see issue #3244).
|
||||||
# mounts:
|
# mounts:
|
||||||
# - host_path: /home/user/my-project # Absolute path on the host machine
|
# - host_path: /home/user/my-project # Absolute path; see note above for Docker mode
|
||||||
# container_path: /mnt/my-project # Virtual path inside the sandbox
|
# container_path: /mnt/my-project # Virtual path inside the sandbox
|
||||||
# read_only: true # Whether the mount is read-only (default: false)
|
# read_only: true # Whether the mount is read-only (default: false)
|
||||||
|
|
||||||
@@ -1020,6 +1024,15 @@ memory:
|
|||||||
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
|
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
|
||||||
injection_enabled: true # Whether to inject memory into system prompt
|
injection_enabled: true # Whether to inject memory into system prompt
|
||||||
max_injection_tokens: 2000 # Maximum tokens for memory injection
|
max_injection_tokens: 2000 # Maximum tokens for memory injection
|
||||||
|
# Token counting strategy for memory-injection budgeting:
|
||||||
|
# tiktoken (default) - accurate, but the encoding's BPE data may be
|
||||||
|
# downloaded from a public network endpoint on first use. In
|
||||||
|
# network-restricted environments this download can block for a long
|
||||||
|
# time (see issues #3402 / #3429). Pre-cache the encoding or set this
|
||||||
|
# to "char" to avoid it.
|
||||||
|
# char - network-free CJK-aware character-based estimate; never touches
|
||||||
|
# tiktoken. Slightly less precise budgeting, zero network I/O.
|
||||||
|
token_counting: tiktoken
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# Custom Agent Management API
|
# Custom Agent Management API
|
||||||
@@ -1127,6 +1140,45 @@ run_events:
|
|||||||
max_trace_content: 10240
|
max_trace_content: 10240
|
||||||
track_token_usage: true
|
track_token_usage: true
|
||||||
|
|
||||||
|
# ============================================================================
|
||||||
|
# User-Owned IM Channel Connections
|
||||||
|
# ============================================================================
|
||||||
|
# Lets logged-in users connect their own IM accounts from the DeerFlow frontend
|
||||||
|
# while reusing the existing `channels` runtime configuration below.
|
||||||
|
#
|
||||||
|
# Security notes:
|
||||||
|
# - No public IP, OAuth callback URL, or provider webhook is required.
|
||||||
|
# - Provider bot/app credentials stay under `channels.*`.
|
||||||
|
# - `channel_connections` stores per-user bindings and one-time connect codes.
|
||||||
|
# - Telegram uses a deep link when `bot_username` is configured.
|
||||||
|
# - Slack, Discord, Feishu, DingTalk, WeChat, and WeCom use `/connect <code>`
|
||||||
|
# through the already-running bot/app.
|
||||||
|
#
|
||||||
|
# channel_connections:
|
||||||
|
# enabled: false
|
||||||
|
#
|
||||||
|
# telegram:
|
||||||
|
# enabled: false
|
||||||
|
# bot_username: $TELEGRAM_BOT_USERNAME
|
||||||
|
#
|
||||||
|
# slack:
|
||||||
|
# enabled: false
|
||||||
|
#
|
||||||
|
# discord:
|
||||||
|
# enabled: false
|
||||||
|
#
|
||||||
|
# feishu:
|
||||||
|
# enabled: false
|
||||||
|
#
|
||||||
|
# dingtalk:
|
||||||
|
# enabled: false
|
||||||
|
#
|
||||||
|
# wechat:
|
||||||
|
# enabled: false
|
||||||
|
#
|
||||||
|
# wecom:
|
||||||
|
# enabled: false
|
||||||
|
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
# IM Channels Configuration
|
# IM Channels Configuration
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
|
|||||||
@@ -72,7 +72,13 @@ services:
|
|||||||
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
|
UV_INDEX_URL: ${UV_INDEX_URL:-https://pypi.org/simple}
|
||||||
UV_EXTRAS: ${UV_EXTRAS:-}
|
UV_EXTRAS: ${UV_EXTRAS:-}
|
||||||
container_name: deer-flow-gateway
|
container_name: deer-flow-gateway
|
||||||
command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-4}"
|
# Gateway hosts the agent runtime with in-process RunManager + StreamBridge
|
||||||
|
# singletons -- run state lives in this worker's memory. Default to a single
|
||||||
|
# worker: with >1 worker and no nginx sticky sessions, run cancel, SSE
|
||||||
|
# reconnect, request dedup, and per-worker IM channel services all break
|
||||||
|
# across workers until a shared (e.g. redis) stream bridge lands, which is
|
||||||
|
# not yet implemented. Override GATEWAY_WORKERS only once that is in place.
|
||||||
|
command: sh -c "cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 --workers ${GATEWAY_WORKERS:-1}"
|
||||||
volumes:
|
volumes:
|
||||||
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
|
- ${DEER_FLOW_CONFIG_PATH}:/app/backend/config.yaml:ro
|
||||||
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
|
- ${DEER_FLOW_EXTENSIONS_CONFIG_PATH}:/app/backend/extensions_config.json:ro
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ import { defineConfig, devices } from "@playwright/test";
|
|||||||
* so the mock-based suite is untouched.
|
* so the mock-based suite is untouched.
|
||||||
*
|
*
|
||||||
* Two webServers are started: the replay gateway (:8011) and the frontend
|
* Two webServers are started: the replay gateway (:8011) and the frontend
|
||||||
* (:3000, pointed at the gateway). Auth uses a throwaway test account the spec
|
* (:3000, pointed at the gateway). Auth-disabled mode is enabled on both
|
||||||
* registers at runtime — no secrets.
|
* servers so the no-cookie e2e contract is covered; specs that need session
|
||||||
|
* cookies still register a throwaway test account at runtime.
|
||||||
*/
|
*/
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
testDir: "./tests/e2e-real-backend",
|
testDir: "./tests/e2e-real-backend",
|
||||||
@@ -38,7 +39,10 @@ export default defineConfig({
|
|||||||
// Mount the test-only run/message seeder used by multi-run-order.spec.ts
|
// Mount the test-only run/message seeder used by multi-run-order.spec.ts
|
||||||
// (#3352). The endpoint exists only on this replay gateway, never in the
|
// (#3352). The endpoint exists only on this replay gateway, never in the
|
||||||
// production app.
|
// production app.
|
||||||
env: { DEERFLOW_ENABLE_TEST_SEED: "1" },
|
env: {
|
||||||
|
DEERFLOW_ENABLE_TEST_SEED: "1",
|
||||||
|
DEER_FLOW_AUTH_DISABLED: "1",
|
||||||
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
command: "pnpm build && pnpm start",
|
command: "pnpm build && pnpm start",
|
||||||
|
|||||||
@@ -1,34 +1,77 @@
|
|||||||
"use client";
|
"use client";
|
||||||
|
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { useEffect, useMemo, useState } from "react";
|
import { useEffect, useMemo, useRef, useState } from "react";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
import { Input } from "@/components/ui/input";
|
import { Input } from "@/components/ui/input";
|
||||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||||
|
import {
|
||||||
|
ThreadChannelBadge,
|
||||||
|
ThreadChannelIcon,
|
||||||
|
} from "@/components/workspace/thread-channel-source";
|
||||||
import {
|
import {
|
||||||
WorkspaceBody,
|
WorkspaceBody,
|
||||||
WorkspaceContainer,
|
WorkspaceContainer,
|
||||||
WorkspaceHeader,
|
WorkspaceHeader,
|
||||||
} from "@/components/workspace/workspace-container";
|
} from "@/components/workspace/workspace-container";
|
||||||
import { useI18n } from "@/core/i18n/hooks";
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
import { useThreads } from "@/core/threads/hooks";
|
import { useInfiniteThreads } from "@/core/threads/hooks";
|
||||||
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
import {
|
||||||
|
channelSourceOfThread,
|
||||||
|
pathOfThread,
|
||||||
|
titleOfThread,
|
||||||
|
} from "@/core/threads/utils";
|
||||||
import { formatTimeAgo } from "@/core/utils/datetime";
|
import { formatTimeAgo } from "@/core/utils/datetime";
|
||||||
|
|
||||||
export default function ChatsPage() {
|
export default function ChatsPage() {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
const { data: threads } = useThreads();
|
const {
|
||||||
|
data: infiniteThreads,
|
||||||
|
fetchNextPage,
|
||||||
|
hasNextPage,
|
||||||
|
isFetchingNextPage,
|
||||||
|
} = useInfiniteThreads();
|
||||||
|
const threads = useMemo(
|
||||||
|
() => infiniteThreads?.pages.flat() ?? [],
|
||||||
|
[infiniteThreads],
|
||||||
|
);
|
||||||
const [search, setSearch] = useState("");
|
const [search, setSearch] = useState("");
|
||||||
|
const isSearching = search.trim().length > 0;
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
document.title = `${t.pages.chats} - ${t.pages.appName}`;
|
document.title = `${t.pages.chats} - ${t.pages.appName}`;
|
||||||
}, [t.pages.chats, t.pages.appName]);
|
}, [t.pages.chats, t.pages.appName]);
|
||||||
|
|
||||||
const filteredThreads = useMemo(() => {
|
const filteredThreads = useMemo(() => {
|
||||||
return threads?.filter((thread) => {
|
return threads.filter((thread) => {
|
||||||
return titleOfThread(thread).toLowerCase().includes(search.toLowerCase());
|
return titleOfThread(thread).toLowerCase().includes(search.toLowerCase());
|
||||||
});
|
});
|
||||||
}, [threads, search]);
|
}, [threads, search]);
|
||||||
|
|
||||||
|
// Sentinel-based auto load-more for the unfiltered list (issue #3482).
|
||||||
|
// In search mode we deliberately do NOT auto-paginate, otherwise an empty
|
||||||
|
// filtered view would keep the sentinel in the viewport and drain the
|
||||||
|
// entire backend list one page at a time. Searching falls back to an
|
||||||
|
// explicit button so users can still reach older conversations on demand.
|
||||||
|
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const element = sentinelRef.current;
|
||||||
|
if (!element || !hasNextPage || isSearching) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const observer = new IntersectionObserver(
|
||||||
|
([entry]) => {
|
||||||
|
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
|
||||||
|
void fetchNextPage();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ rootMargin: "200px 0px 200px 0px" },
|
||||||
|
);
|
||||||
|
observer.observe(element);
|
||||||
|
return () => observer.disconnect();
|
||||||
|
}, [fetchNextPage, hasNextPage, isFetchingNextPage, isSearching]);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<WorkspaceContainer>
|
<WorkspaceContainer>
|
||||||
<WorkspaceHeader></WorkspaceHeader>
|
<WorkspaceHeader></WorkspaceHeader>
|
||||||
@@ -47,20 +90,52 @@ export default function ChatsPage() {
|
|||||||
<main className="min-h-0 flex-1">
|
<main className="min-h-0 flex-1">
|
||||||
<ScrollArea className="size-full py-4">
|
<ScrollArea className="size-full py-4">
|
||||||
<div className="mx-auto flex size-full max-w-(--container-width-md) flex-col">
|
<div className="mx-auto flex size-full max-w-(--container-width-md) flex-col">
|
||||||
{filteredThreads?.map((thread) => (
|
{filteredThreads.map((thread) => {
|
||||||
<Link key={thread.thread_id} href={pathOfThread(thread)}>
|
const channelSource = channelSourceOfThread(thread);
|
||||||
<div className="flex flex-col gap-2 border-b p-4">
|
return (
|
||||||
<div>
|
<Link key={thread.thread_id} href={pathOfThread(thread)}>
|
||||||
<div>{titleOfThread(thread)}</div>
|
<div className="flex flex-col gap-2 border-b p-4">
|
||||||
</div>
|
<div className="flex min-w-0 items-center gap-2">
|
||||||
{thread.updated_at && (
|
<ThreadChannelIcon source={channelSource} />
|
||||||
<div className="text-muted-foreground text-sm">
|
<div className="min-w-0 flex-1 truncate">
|
||||||
{formatTimeAgo(thread.updated_at)}
|
{titleOfThread(thread)}
|
||||||
|
</div>
|
||||||
|
<ThreadChannelBadge
|
||||||
|
source={channelSource}
|
||||||
|
className="hidden sm:inline-flex"
|
||||||
|
/>
|
||||||
</div>
|
</div>
|
||||||
)}
|
{thread.updated_at && (
|
||||||
</div>
|
<div className="text-muted-foreground text-sm">
|
||||||
</Link>
|
{formatTimeAgo(thread.updated_at)}
|
||||||
))}
|
</div>
|
||||||
|
)}
|
||||||
|
</div>
|
||||||
|
</Link>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
{hasNextPage && !isSearching && (
|
||||||
|
<div
|
||||||
|
ref={sentinelRef}
|
||||||
|
aria-hidden="true"
|
||||||
|
className="h-px w-full"
|
||||||
|
data-testid="chats-page-sentinel"
|
||||||
|
/>
|
||||||
|
)}
|
||||||
|
{hasNextPage && isSearching && (
|
||||||
|
<div className="flex justify-center p-4">
|
||||||
|
<Button
|
||||||
|
variant="outline"
|
||||||
|
onClick={() => void fetchNextPage()}
|
||||||
|
disabled={isFetchingNextPage}
|
||||||
|
data-testid="chats-page-load-more"
|
||||||
|
>
|
||||||
|
{isFetchingNextPage
|
||||||
|
? t.chats.loadingMore
|
||||||
|
: t.chats.loadMoreToSearch}
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</ScrollArea>
|
</ScrollArea>
|
||||||
</main>
|
</main>
|
||||||
|
|||||||
@@ -0,0 +1,184 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { MessageCircleIcon } from "lucide-react";
|
||||||
|
import type { SVGProps } from "react";
|
||||||
|
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
type ChannelProviderIconProps = SVGProps<SVGSVGElement> & {
|
||||||
|
provider: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ChannelProviderIcon({
|
||||||
|
provider,
|
||||||
|
className,
|
||||||
|
...props
|
||||||
|
}: ChannelProviderIconProps) {
|
||||||
|
const normalizedProvider = provider.toLowerCase();
|
||||||
|
|
||||||
|
if (normalizedProvider === "telegram") {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
aria-hidden="true"
|
||||||
|
className={cn("size-5", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<circle cx="12" cy="12" r="11" fill="#2AABEE" />
|
||||||
|
<path
|
||||||
|
fill="#FFFFFF"
|
||||||
|
d="M17.4 7.2 15.7 16c-.1.7-.5.9-1 .6l-2.8-2.1-1.4 1.3c-.1.2-.3.3-.6.3l.2-2.9 5.3-4.8c.2-.2 0-.3-.3-.1l-6.6 4.1-2.8-.9c-.6-.2-.6-.6.1-.8l10.9-4.2c.5-.2.9.1.7.7Z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedProvider === "slack") {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
viewBox="0 0 256 256"
|
||||||
|
aria-hidden="true"
|
||||||
|
className={cn("size-5", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<path
|
||||||
|
fill="#E01E5A"
|
||||||
|
d="M53.841 161.32c0 14.832-11.987 26.82-26.819 26.82S.203 176.152.203 161.32c0-14.831 11.987-26.818 26.82-26.818H53.84zm13.41 0c0-14.831 11.987-26.818 26.819-26.818s26.819 11.987 26.819 26.819v67.047c0 14.832-11.987 26.82-26.82 26.82c-14.83 0-26.818-11.988-26.818-26.82z"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
fill="#36C5F0"
|
||||||
|
d="M94.07 53.638c-14.832 0-26.82-11.987-26.82-26.819S79.239 0 94.07 0s26.819 11.987 26.819 26.819v26.82zm0 13.613c14.832 0 26.819 11.987 26.819 26.819s-11.987 26.819-26.82 26.819H26.82C11.987 120.889 0 108.902 0 94.069c0-14.83 11.987-26.818 26.819-26.818z"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
fill="#2EB67D"
|
||||||
|
d="M201.55 94.07c0-14.832 11.987-26.82 26.818-26.82s26.82 11.988 26.82 26.82s-11.988 26.819-26.82 26.819H201.55zm-13.41 0c0 14.832-11.988 26.819-26.82 26.819c-14.831 0-26.818-11.987-26.818-26.82V26.82C134.502 11.987 146.489 0 161.32 0s26.819 11.987 26.819 26.819z"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
fill="#ECB22E"
|
||||||
|
d="M161.32 201.55c14.832 0 26.82 11.987 26.82 26.818s-11.988 26.82-26.82 26.82c-14.831 0-26.818-11.988-26.818-26.82V201.55zm0-13.41c-14.831 0-26.818-11.988-26.818-26.82c0-14.831 11.987-26.818 26.819-26.818h67.25c14.832 0 26.82 11.987 26.82 26.819s-11.988 26.819-26.82 26.819z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedProvider === "discord") {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
aria-hidden="true"
|
||||||
|
className={cn("size-5", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<circle cx="12" cy="12" r="11" fill="#5865F2" />
|
||||||
|
<path
|
||||||
|
fill="#FFFFFF"
|
||||||
|
d="M8.1 8.4c1.4-.6 2.7-.7 3.9-.7s2.5.1 3.9.7c1 1.5 1.5 3.1 1.4 4.8-.9.7-1.8 1.1-2.8 1.3l-.7-1.1c.4-.1.7-.3 1.1-.5-.3.1-.6.3-.9.4-.7.3-1.4.4-2 .4s-1.3-.1-2-.4c-.3-.1-.6-.2-.9-.4.3.2.7.4 1.1.5l-.7 1.1c-1-.2-1.9-.6-2.8-1.3-.1-1.7.4-3.3 1.4-4.8Zm2.1 3.9c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Zm3.6 0c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedProvider === "feishu") {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
aria-hidden="true"
|
||||||
|
className={cn("size-5", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<rect
|
||||||
|
x="1.25"
|
||||||
|
y="1.25"
|
||||||
|
width="21.5"
|
||||||
|
height="21.5"
|
||||||
|
rx="5.25"
|
||||||
|
fill="#FFFFFF"
|
||||||
|
stroke="#E5E7EB"
|
||||||
|
strokeWidth=".5"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
d="M6.1 4.5h8.3c.9 0 1.7.4 2.2 1.1a16 16 0 0 1 2.9 6.2c-1.8-.8-3.8-.9-5.9-.3L5.6 5.8c-.6-.5-.3-1.3.5-1.3Z"
|
||||||
|
fill="#14D6C5"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
d="M3.2 8.9c3.6 3.4 7.5 5.7 11.7 6.8 2.7.7 5.1.4 7-.6-1.6 3.1-5.2 5.4-9.4 5.4-3.4 0-6.7-.9-9.2-2.6a2 2 0 0 1-.9-1.7V9.6c0-.7.4-1 .8-.7Z"
|
||||||
|
fill="#3370FF"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
d="M11 14.1c2.3-3.1 6.1-4.6 10.5-3.3l.8.2-2.6 4.1a6.3 6.3 0 0 1-6 2.9c-1.9-.2-3.9-.8-5.9-1.7 1.1-.7 2.2-1.4 3.2-2.2Z"
|
||||||
|
fill="#1E3A9F"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedProvider === "dingtalk") {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
viewBox="0 0 1024 1024"
|
||||||
|
aria-hidden="true"
|
||||||
|
className={cn("size-5", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<g transform="translate(512 512) scale(1.35) translate(-512 -512)">
|
||||||
|
<path
|
||||||
|
fill="#0B86FF"
|
||||||
|
d="M739 449.3c-1 4.2-3.5 10.4-7 17.8h.1l-.4.7c-20.3 43.1-73.1 127.7-73.1 127.7l-.3-.5-15.5 26.8h74.5L575.1 810l32.3-128h-58.6l20.4-84.7c-16.5 3.9-35.9 9.4-59 16.8 0 0-31.2 18.2-89.9-35 0 0-39.6-34.7-16.6-43.4 9.8-3.7 47.4-8.4 77-12.3 40-5.4 64.6-8.2 64.6-8.2S422 517 392.7 512.5c-29.3-4.6-66.4-53.1-74.3-95.8 0 0-12.2-23.4 26.3-12.3s197.9 43.2 197.9 43.2-207.4-63.3-221.2-78.7-40.6-84.2-37.1-126.5c0 0 1.5-10.5 12.4-7.7 0 0 153.3 69.7 258.1 107.9 104.8 37.9 195.9 57.3 184.2 106.7Z"
|
||||||
|
/>
|
||||||
|
</g>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedProvider === "wechat") {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
aria-hidden="true"
|
||||||
|
className={cn("size-5", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<circle cx="12" cy="12" r="11" fill="#07C160" />
|
||||||
|
<path
|
||||||
|
fill="#FFFFFF"
|
||||||
|
d="M10.4 6.5c-3 0-5.4 2-5.4 4.5 0 1.4.8 2.7 2.1 3.5l-.5 1.8 2-.9c.6.1 1.2.2 1.8.2 3 0 5.4-2 5.4-4.5s-2.4-4.6-5.4-4.6Zm-1.9 3.7a.7.7 0 1 1 0-1.4.7.7 0 0 1 0 1.4Zm3.7 0a.7.7 0 1 1 0-1.4.7.7 0 0 1 0 1.4Z"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
fill="#FFFFFF"
|
||||||
|
fillOpacity=".86"
|
||||||
|
d="M14.4 12.3c2.5 0 4.6 1.7 4.6 3.8 0 1.1-.6 2.1-1.6 2.8l.4 1.5-1.7-.8c-.5.1-1.1.2-1.7.2-2.5 0-4.6-1.7-4.6-3.8s2.1-3.7 4.6-3.7Zm-1.6 3.1a.6.6 0 1 0 0-1.2.6.6 0 0 0 0 1.2Zm3.1 0a.6.6 0 1 0 0-1.2.6.6 0 0 0 0 1.2Z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedProvider === "wecom") {
|
||||||
|
return (
|
||||||
|
<svg
|
||||||
|
viewBox="0 0 24 24"
|
||||||
|
aria-hidden="true"
|
||||||
|
className={cn("size-5", className)}
|
||||||
|
{...props}
|
||||||
|
>
|
||||||
|
<rect
|
||||||
|
x="1.25"
|
||||||
|
y="1.25"
|
||||||
|
width="21.5"
|
||||||
|
height="21.5"
|
||||||
|
rx="5.25"
|
||||||
|
fill="#FFFFFF"
|
||||||
|
stroke="#E5E7EB"
|
||||||
|
strokeWidth=".5"
|
||||||
|
/>
|
||||||
|
<path
|
||||||
|
fill="#168DEB"
|
||||||
|
d="m17.326 8.158-.003-.007a6.6 6.6 0 0 0-1.178-1.674c-1.266-1.307-3.067-2.19-5.102-2.417a9.3 9.3 0 0 0-2.124 0h-.001c-2.061.228-3.882 1.107-5.14 2.405a6.7 6.7 0 0 0-1.194 1.682A5.7 5.7 0 0 0 2 10.657c0 1.106.332 2.218.988 3.201l.006.01c.391.594 1.092 1.39 1.637 1.83l.983.793-.208.875.527-.267.708-.358.761.225c.467.137.955.227 1.517.29h.005q.515.06 1.026.059c.355 0 .724-.02 1.095-.06a9 9 0 0 0 1.346-.258c.095.7.43 1.337.932 1.81-.658.208-1.352.358-2.061.436-.442.048-.883.072-1.312.072q-.627 0-1.253-.072a10.7 10.7 0 0 1-1.861-.36l-2.84 1.438s-.29.131-.44.131c-.418 0-.702-.285-.702-.704 0-.252.067-.598.128-.84l.394-1.653c-.728-.586-1.563-1.544-2.052-2.287A7.76 7.76 0 0 1 0 10.658a7.7 7.7 0 0 1 .787-3.39 8.7 8.7 0 0 1 1.551-2.19c1.61-1.665 3.878-2.73 6.359-3.006a11.3 11.3 0 0 1 2.565 0c2.47.275 4.712 1.353 6.323 3.017a8.6 8.6 0 0 1 1.539 2.192c.466.945.769 1.937.769 2.978a3.06 3.06 0 0 0-2-.005c-.001-.644-.189-1.329-.564-2.09zm4.125 6.977-.024-.024-.024-.018-.024-.018-.096-.095a4.24 4.24 0 0 1-1.169-2.192q0-.038-.006-.075l-.006-.056-.035-.144a1.3 1.3 0 0 0-.358-.61 1.386 1.386 0 0 0-1.957 0 1.4 1.4 0 0 0 0 1.963c.191.191.418.311.668.371.024.012.06.012.084.012q.019 0 .041.006.023.005.042.006a4.24 4.24 0 0 1 2.231 1.186c.048.048.096.095.131.143a.323.323 0 0 0 .466 0 .35.35 0 0 0 .036-.455m-1.05 4.37-.025.025c-.119.096-.31.096-.453-.036a.326.326 0 0 1 0-.467c.047-.036.094-.083.141-.13l.002-.002a4.27 4.27 0 0 0 1.187-2.28q.005-.024.006-.043c0-.024 0-.06.012-.084a1.386 1.386 0 0 1 2.326-.67 1.4 1.4 0 0 1 0 1.964c-.167.18-.382.299-.608.359l-.143.036-.057.005q-.035.006-.075.007a4.2 4.2 0 0 0-2.183 1.173l-.095.096q-.009.01-.018.024t-.018.024m-4.392-1.053.024.024.024.018q.015.009.024.018l.096.096a4.25 4.25 0 0 1 1.169 2.19q0 .04.006.076.005.03.006.057l.035.143c.06.228.18.443.358.611.537.539 1.42.539 1.957 0a1.4 1.4 0 0 0 0-1.964 1.4 1.4 0 0 0-.668-.371c-.024-.012-.06-.012-.084-.012q-.018 0-.041-.006l-.042-.006a4.25 4.25 0 0 1-2.231-1.185 1.4 1.4 0 0 1-.131-.144.323.323 0 0 0-.466 0 .325.325 0 0 0-.036.455m1.039-4.358.024-.024a.32.32 0 0 1 .453.035.326.326 0 0 1 0 .467c-.047.036-.094.083-.141.13l-.002.002a4.27 4.27 0 0 0-1.187 2.281l-.006.042c0 .024 0 .06-.012.084a1.386 1.386 0 0 1-2.326.67 1.4 1.4 0 0 1 0-1.963c.166-.18.381-.3.608-.36l.143-.035q.026 0 .056-.006.037-.005.075-.006a4.2 4.2 0 0 0 2.183-1.174l.096-.095.018-.025z"
|
||||||
|
/>
|
||||||
|
</svg>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<MessageCircleIcon aria-hidden="true" className={cn("size-5", className)} />
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,159 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { LoaderCircleIcon } from "lucide-react";
|
||||||
|
import {
|
||||||
|
type CSSProperties,
|
||||||
|
type FormEvent,
|
||||||
|
useEffect,
|
||||||
|
useMemo,
|
||||||
|
useState,
|
||||||
|
} from "react";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
Dialog,
|
||||||
|
DialogContent,
|
||||||
|
DialogDescription,
|
||||||
|
DialogFooter,
|
||||||
|
DialogHeader,
|
||||||
|
DialogTitle,
|
||||||
|
} from "@/components/ui/dialog";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import type {
|
||||||
|
ChannelProvider,
|
||||||
|
ChannelRuntimeConfigValues,
|
||||||
|
} from "@/core/channels/types";
|
||||||
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
|
|
||||||
|
type ChannelRuntimeConfigDialogProps = {
|
||||||
|
provider: ChannelProvider | null;
|
||||||
|
open: boolean;
|
||||||
|
submitting: boolean;
|
||||||
|
onOpenChange: (open: boolean) => void;
|
||||||
|
onSubmit: (
|
||||||
|
provider: ChannelProvider,
|
||||||
|
values: ChannelRuntimeConfigValues,
|
||||||
|
) => void;
|
||||||
|
};
|
||||||
|
|
||||||
|
type SecretInputStyle = CSSProperties & {
|
||||||
|
WebkitTextSecurity?: "disc";
|
||||||
|
};
|
||||||
|
|
||||||
|
const SECRET_INPUT_STYLE: SecretInputStyle = {
|
||||||
|
WebkitTextSecurity: "disc",
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ChannelRuntimeConfigDialog({
|
||||||
|
provider,
|
||||||
|
open,
|
||||||
|
submitting,
|
||||||
|
onOpenChange,
|
||||||
|
onSubmit,
|
||||||
|
}: ChannelRuntimeConfigDialogProps) {
|
||||||
|
const { t } = useI18n();
|
||||||
|
const [values, setValues] = useState<ChannelRuntimeConfigValues>({});
|
||||||
|
const fields = useMemo(
|
||||||
|
() => provider?.credential_fields ?? [],
|
||||||
|
[provider?.credential_fields],
|
||||||
|
);
|
||||||
|
const credentialValues = useMemo<ChannelRuntimeConfigValues>(
|
||||||
|
() => provider?.credential_values ?? {},
|
||||||
|
[provider?.credential_values],
|
||||||
|
);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
if (!open || !provider) {
|
||||||
|
setValues({});
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
setValues(
|
||||||
|
Object.fromEntries(
|
||||||
|
fields.map((field) => [field.name, credentialValues[field.name] ?? ""]),
|
||||||
|
) as ChannelRuntimeConfigValues,
|
||||||
|
);
|
||||||
|
}, [credentialValues, fields, open, provider]);
|
||||||
|
|
||||||
|
if (!provider) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
const isEditing = provider.configured;
|
||||||
|
|
||||||
|
const handleSubmit = (event: FormEvent<HTMLFormElement>) => {
|
||||||
|
event.preventDefault();
|
||||||
|
onSubmit(provider, values);
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<Dialog open={open} onOpenChange={onOpenChange}>
|
||||||
|
<DialogContent>
|
||||||
|
<form onSubmit={handleSubmit} className="space-y-4">
|
||||||
|
<DialogHeader>
|
||||||
|
<DialogTitle>
|
||||||
|
{isEditing
|
||||||
|
? t.channels.setupEditTitle(provider.display_name)
|
||||||
|
: t.channels.setupTitle(provider.display_name)}
|
||||||
|
</DialogTitle>
|
||||||
|
<DialogDescription>{t.channels.setupDescription}</DialogDescription>
|
||||||
|
</DialogHeader>
|
||||||
|
|
||||||
|
<div className="space-y-3">
|
||||||
|
{fields.map((field) => {
|
||||||
|
const inputId = `channel-${provider.provider}-${field.name}`;
|
||||||
|
const isSecretField = field.type === "password";
|
||||||
|
return (
|
||||||
|
<div key={field.name} className="space-y-1.5">
|
||||||
|
<label
|
||||||
|
htmlFor={inputId}
|
||||||
|
className="text-sm leading-none font-medium"
|
||||||
|
>
|
||||||
|
{field.label}
|
||||||
|
</label>
|
||||||
|
<Input
|
||||||
|
id={inputId}
|
||||||
|
type="text"
|
||||||
|
value={values[field.name] ?? ""}
|
||||||
|
required={field.required}
|
||||||
|
autoComplete="off"
|
||||||
|
autoCorrect="off"
|
||||||
|
autoCapitalize="none"
|
||||||
|
spellCheck={false}
|
||||||
|
data-1p-ignore={isSecretField ? "true" : undefined}
|
||||||
|
data-bwignore={isSecretField ? "true" : undefined}
|
||||||
|
data-form-type={isSecretField ? "other" : undefined}
|
||||||
|
data-lpignore={isSecretField ? "true" : undefined}
|
||||||
|
style={isSecretField ? SECRET_INPUT_STYLE : undefined}
|
||||||
|
onChange={(event) => {
|
||||||
|
setValues((current) => ({
|
||||||
|
...current,
|
||||||
|
[field.name]: event.target.value,
|
||||||
|
}));
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<DialogFooter>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
disabled={submitting}
|
||||||
|
onClick={() => onOpenChange(false)}
|
||||||
|
>
|
||||||
|
{t.common.cancel}
|
||||||
|
</Button>
|
||||||
|
<Button type="submit" disabled={submitting}>
|
||||||
|
{submitting ? (
|
||||||
|
<LoaderCircleIcon className="animate-spin" />
|
||||||
|
) : null}
|
||||||
|
{isEditing ? t.channels.saveChanges : t.channels.saveAndConnect}
|
||||||
|
</Button>
|
||||||
|
</DialogFooter>
|
||||||
|
</form>
|
||||||
|
</DialogContent>
|
||||||
|
</Dialog>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,219 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { CheckIcon, LoaderCircleIcon } from "lucide-react";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
SidebarGroup,
|
||||||
|
SidebarGroupLabel,
|
||||||
|
SidebarMenu,
|
||||||
|
SidebarMenuItem,
|
||||||
|
useSidebar,
|
||||||
|
} from "@/components/ui/sidebar";
|
||||||
|
import { Skeleton } from "@/components/ui/skeleton";
|
||||||
|
import {
|
||||||
|
useConfigureChannelProvider,
|
||||||
|
useChannelProviders,
|
||||||
|
useConnectChannelProvider,
|
||||||
|
} from "@/core/channels/hooks";
|
||||||
|
import {
|
||||||
|
closeConnectWindow,
|
||||||
|
openConnectUrl,
|
||||||
|
prepareConnectWindow,
|
||||||
|
} from "@/core/channels/open-connect-url";
|
||||||
|
import type { ChannelProvider } from "@/core/channels/types";
|
||||||
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
import { ChannelProviderIcon } from "./channel-provider-icon";
|
||||||
|
import { ChannelRuntimeConfigDialog } from "./channel-runtime-config-dialog";
|
||||||
|
|
||||||
|
function providerCanConnect(provider: ChannelProvider): boolean {
|
||||||
|
return (
|
||||||
|
(provider.connectable ?? (provider.enabled && provider.configured)) &&
|
||||||
|
provider.connection_status !== "connected"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function providerCanEditRuntimeConfig(provider: ChannelProvider): boolean {
|
||||||
|
return provider.enabled && (provider.credential_fields?.length ?? 0) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getProviderUnavailableReason(
|
||||||
|
provider: ChannelProvider,
|
||||||
|
t: ReturnType<typeof useI18n>["t"],
|
||||||
|
): string | undefined {
|
||||||
|
if (provider.unavailable_reason) {
|
||||||
|
return provider.unavailable_reason;
|
||||||
|
}
|
||||||
|
if (!provider.enabled) {
|
||||||
|
return t.channels.disabled;
|
||||||
|
}
|
||||||
|
if (!provider.configured) {
|
||||||
|
return t.channels.unconfigured;
|
||||||
|
}
|
||||||
|
return provider.unavailable_reason ?? undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
function providerNeedsRuntimeConfig(provider: ChannelProvider): boolean {
|
||||||
|
return (
|
||||||
|
provider.enabled &&
|
||||||
|
!provider.configured &&
|
||||||
|
(provider.credential_fields?.length ?? 0) > 0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function WorkspaceChannelsList() {
|
||||||
|
const { open: isSidebarOpen } = useSidebar();
|
||||||
|
const { t } = useI18n();
|
||||||
|
const { enabled, providers, isLoading, error } = useChannelProviders();
|
||||||
|
const connectMutation = useConnectChannelProvider();
|
||||||
|
const configureMutation = useConfigureChannelProvider();
|
||||||
|
const [setupProvider, setSetupProvider] = useState<ChannelProvider | null>(
|
||||||
|
null,
|
||||||
|
);
|
||||||
|
const visibleProviders = providers.filter((provider) => provider.enabled);
|
||||||
|
|
||||||
|
const startConnect = (
|
||||||
|
provider: ChannelProvider,
|
||||||
|
preparedWindow?: Window | null,
|
||||||
|
) => {
|
||||||
|
const connectWindow =
|
||||||
|
preparedWindow !== undefined
|
||||||
|
? preparedWindow
|
||||||
|
: provider.auth_mode === "deep_link"
|
||||||
|
? prepareConnectWindow()
|
||||||
|
: null;
|
||||||
|
void connectMutation
|
||||||
|
.mutateAsync(provider.provider)
|
||||||
|
.then((result) => {
|
||||||
|
if (result.url) {
|
||||||
|
openConnectUrl(result.url, connectWindow);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
closeConnectWindow(connectWindow);
|
||||||
|
toast.success(result.instruction);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
closeConnectWindow(connectWindow);
|
||||||
|
toast.error(
|
||||||
|
error instanceof Error ? error.message : t.channels.unavailable,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
if (!isSidebarOpen) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isLoading) {
|
||||||
|
return (
|
||||||
|
<SidebarGroup className="pt-0">
|
||||||
|
<SidebarGroupLabel>{t.sidebar.channels}</SidebarGroupLabel>
|
||||||
|
<div className="space-y-2 px-2 py-1">
|
||||||
|
<Skeleton className="h-8 w-full" />
|
||||||
|
<Skeleton className="h-8 w-full" />
|
||||||
|
<Skeleton className="h-8 w-full" />
|
||||||
|
</div>
|
||||||
|
</SidebarGroup>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (error || !enabled || visibleProviders.length === 0) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<SidebarGroup className="pt-0">
|
||||||
|
<SidebarGroupLabel>{t.sidebar.channels}</SidebarGroupLabel>
|
||||||
|
<SidebarMenu>
|
||||||
|
{visibleProviders.map((provider) => {
|
||||||
|
const canEditRuntimeConfig = providerCanEditRuntimeConfig(provider);
|
||||||
|
const isConnected = provider.connection_status === "connected";
|
||||||
|
const isPending =
|
||||||
|
(connectMutation.isPending &&
|
||||||
|
connectMutation.variables === provider.provider) ||
|
||||||
|
(configureMutation.isPending &&
|
||||||
|
configureMutation.variables?.provider === provider.provider);
|
||||||
|
const canConnect = providerCanConnect(provider);
|
||||||
|
const unavailableReason = getProviderUnavailableReason(provider, t);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<SidebarMenuItem key={provider.provider}>
|
||||||
|
<div className="hover:bg-sidebar-accent flex h-10 items-center gap-2 rounded-md px-2 transition-colors">
|
||||||
|
<ChannelProviderIcon
|
||||||
|
provider={provider.provider}
|
||||||
|
className="size-5 shrink-0"
|
||||||
|
/>
|
||||||
|
<span className="min-w-0 flex-1 truncate text-sm font-medium">
|
||||||
|
{provider.display_name}
|
||||||
|
</span>
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
size="sm"
|
||||||
|
variant={isConnected ? "outline" : "secondary"}
|
||||||
|
className={cn(
|
||||||
|
"h-8 w-24 px-2 text-xs",
|
||||||
|
isConnected && "gap-1",
|
||||||
|
)}
|
||||||
|
disabled={isPending}
|
||||||
|
title={unavailableReason}
|
||||||
|
onClick={() => {
|
||||||
|
if (
|
||||||
|
providerNeedsRuntimeConfig(provider) ||
|
||||||
|
canEditRuntimeConfig
|
||||||
|
) {
|
||||||
|
setSetupProvider(provider);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!canConnect) {
|
||||||
|
toast.error(unavailableReason ?? t.channels.unavailable);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
startConnect(provider);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{isPending ? (
|
||||||
|
<LoaderCircleIcon className="size-3.5 animate-spin" />
|
||||||
|
) : isConnected ? (
|
||||||
|
<CheckIcon className="size-3.5" />
|
||||||
|
) : null}
|
||||||
|
<span>
|
||||||
|
{isConnected ? t.channels.connected : t.channels.connect}
|
||||||
|
</span>
|
||||||
|
</Button>
|
||||||
|
</div>
|
||||||
|
</SidebarMenuItem>
|
||||||
|
);
|
||||||
|
})}
|
||||||
|
</SidebarMenu>
|
||||||
|
<ChannelRuntimeConfigDialog
|
||||||
|
provider={setupProvider}
|
||||||
|
open={setupProvider !== null}
|
||||||
|
submitting={configureMutation.isPending}
|
||||||
|
onOpenChange={(open) => {
|
||||||
|
if (!open) {
|
||||||
|
setSetupProvider(null);
|
||||||
|
}
|
||||||
|
}}
|
||||||
|
onSubmit={(provider, values) => {
|
||||||
|
void configureMutation
|
||||||
|
.mutateAsync({ provider: provider.provider, values })
|
||||||
|
.then(() => {
|
||||||
|
setSetupProvider(null);
|
||||||
|
toast.success(t.channels.connected);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
toast.error(
|
||||||
|
error instanceof Error ? error.message : t.channels.unavailable,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</SidebarGroup>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -11,7 +11,7 @@ import {
|
|||||||
} from "lucide-react";
|
} from "lucide-react";
|
||||||
import Link from "next/link";
|
import Link from "next/link";
|
||||||
import { useParams, usePathname, useRouter } from "next/navigation";
|
import { useParams, usePathname, useRouter } from "next/navigation";
|
||||||
import { useCallback, useState } from "react";
|
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||||
import { toast } from "sonner";
|
import { toast } from "sonner";
|
||||||
|
|
||||||
import { Button } from "@/components/ui/button";
|
import { Button } from "@/components/ui/button";
|
||||||
@@ -51,14 +51,20 @@ import {
|
|||||||
} from "@/core/threads/export";
|
} from "@/core/threads/export";
|
||||||
import {
|
import {
|
||||||
useDeleteThread,
|
useDeleteThread,
|
||||||
|
useInfiniteThreads,
|
||||||
useRenameThread,
|
useRenameThread,
|
||||||
useThreads,
|
|
||||||
} from "@/core/threads/hooks";
|
} from "@/core/threads/hooks";
|
||||||
import type { AgentThread, AgentThreadState } from "@/core/threads/types";
|
import type { AgentThread, AgentThreadState } from "@/core/threads/types";
|
||||||
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
import {
|
||||||
|
channelSourceOfThread,
|
||||||
|
pathOfThread,
|
||||||
|
titleOfThread,
|
||||||
|
} from "@/core/threads/utils";
|
||||||
import { env } from "@/env";
|
import { env } from "@/env";
|
||||||
import { isIMEComposing } from "@/lib/ime";
|
import { isIMEComposing } from "@/lib/ime";
|
||||||
|
|
||||||
|
import { ThreadChannelIcon } from "./thread-channel-source";
|
||||||
|
|
||||||
export function RecentChatList() {
|
export function RecentChatList() {
|
||||||
const { t } = useI18n();
|
const { t } = useI18n();
|
||||||
const router = useRouter();
|
const router = useRouter();
|
||||||
@@ -68,7 +74,35 @@ export function RecentChatList() {
|
|||||||
thread_id: string;
|
thread_id: string;
|
||||||
agent_name?: string;
|
agent_name?: string;
|
||||||
}>();
|
}>();
|
||||||
const { data: threads = [] } = useThreads();
|
const {
|
||||||
|
data: infiniteThreads,
|
||||||
|
fetchNextPage,
|
||||||
|
hasNextPage,
|
||||||
|
isFetchingNextPage,
|
||||||
|
} = useInfiniteThreads();
|
||||||
|
const threads = useMemo(
|
||||||
|
() => infiniteThreads?.pages.flat() ?? [],
|
||||||
|
[infiniteThreads],
|
||||||
|
);
|
||||||
|
|
||||||
|
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||||
|
useEffect(() => {
|
||||||
|
const element = sentinelRef.current;
|
||||||
|
if (!element || !hasNextPage) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const observer = new IntersectionObserver(
|
||||||
|
([entry]) => {
|
||||||
|
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
|
||||||
|
void fetchNextPage();
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{ rootMargin: "120px 0px 120px 0px" },
|
||||||
|
);
|
||||||
|
observer.observe(element);
|
||||||
|
return () => observer.disconnect();
|
||||||
|
}, [fetchNextPage, hasNextPage, isFetchingNextPage]);
|
||||||
|
|
||||||
const { mutate: deleteThread } = useDeleteThread();
|
const { mutate: deleteThread } = useDeleteThread();
|
||||||
const { mutate: renameThread } = useRenameThread();
|
const { mutate: renameThread } = useRenameThread();
|
||||||
|
|
||||||
@@ -182,6 +216,7 @@ export function RecentChatList() {
|
|||||||
<div className="flex w-full flex-col gap-1">
|
<div className="flex w-full flex-col gap-1">
|
||||||
{threads.map((thread) => {
|
{threads.map((thread) => {
|
||||||
const isActive = pathOfThread(thread) === pathname;
|
const isActive = pathOfThread(thread) === pathname;
|
||||||
|
const channelSource = channelSourceOfThread(thread);
|
||||||
return (
|
return (
|
||||||
<SidebarMenuItem
|
<SidebarMenuItem
|
||||||
key={thread.thread_id}
|
key={thread.thread_id}
|
||||||
@@ -190,10 +225,23 @@ export function RecentChatList() {
|
|||||||
<SidebarMenuButton isActive={isActive} asChild>
|
<SidebarMenuButton isActive={isActive} asChild>
|
||||||
<div>
|
<div>
|
||||||
<Link
|
<Link
|
||||||
className="text-muted-foreground block w-full whitespace-nowrap group-hover/side-menu-item:overflow-hidden"
|
className="text-muted-foreground flex min-w-0 items-center gap-1.5 pr-7 whitespace-nowrap group-hover/side-menu-item:overflow-hidden"
|
||||||
href={pathOfThread(thread)}
|
href={pathOfThread(thread)}
|
||||||
>
|
>
|
||||||
{titleOfThread(thread)}
|
<ThreadChannelIcon source={channelSource} />
|
||||||
|
<span className="min-w-0 truncate">
|
||||||
|
{titleOfThread(thread)}
|
||||||
|
</span>
|
||||||
|
{channelSource && (
|
||||||
|
<span
|
||||||
|
className="bg-muted text-muted-foreground ml-auto inline-flex h-5 max-w-14 shrink-0 items-center rounded-md px-1.5 text-[10px] font-medium"
|
||||||
|
title={`${channelSource.label} channel`}
|
||||||
|
>
|
||||||
|
<span className="truncate">
|
||||||
|
{channelSource.label}
|
||||||
|
</span>
|
||||||
|
</span>
|
||||||
|
)}
|
||||||
</Link>
|
</Link>
|
||||||
{env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY !== "true" && (
|
{env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY !== "true" && (
|
||||||
<DropdownMenu>
|
<DropdownMenu>
|
||||||
@@ -267,6 +315,28 @@ export function RecentChatList() {
|
|||||||
</SidebarMenuItem>
|
</SidebarMenuItem>
|
||||||
);
|
);
|
||||||
})}
|
})}
|
||||||
|
{hasNextPage && (
|
||||||
|
<>
|
||||||
|
<Button
|
||||||
|
variant="ghost"
|
||||||
|
size="sm"
|
||||||
|
className="mx-2 my-1 w-[calc(100%-1rem)] justify-center text-xs"
|
||||||
|
onClick={() => void fetchNextPage()}
|
||||||
|
disabled={isFetchingNextPage}
|
||||||
|
data-testid="recent-chat-list-load-more"
|
||||||
|
>
|
||||||
|
{isFetchingNextPage
|
||||||
|
? t.chats.loadingMore
|
||||||
|
: t.chats.loadOlderChats}
|
||||||
|
</Button>
|
||||||
|
<div
|
||||||
|
ref={sentinelRef}
|
||||||
|
aria-hidden="true"
|
||||||
|
className="h-px w-full"
|
||||||
|
data-testid="recent-chat-list-sentinel"
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
)}
|
||||||
</div>
|
</div>
|
||||||
</SidebarMenu>
|
</SidebarMenu>
|
||||||
</SidebarGroupContent>
|
</SidebarGroupContent>
|
||||||
|
|||||||
@@ -0,0 +1,364 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import {
|
||||||
|
AlertCircleIcon,
|
||||||
|
CheckCircle2Icon,
|
||||||
|
LoaderCircleIcon,
|
||||||
|
PlugIcon,
|
||||||
|
UnplugIcon,
|
||||||
|
} from "lucide-react";
|
||||||
|
import { useState } from "react";
|
||||||
|
import { toast } from "sonner";
|
||||||
|
|
||||||
|
import { Badge } from "@/components/ui/badge";
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import {
|
||||||
|
Item,
|
||||||
|
ItemActions,
|
||||||
|
ItemContent,
|
||||||
|
ItemDescription,
|
||||||
|
ItemMedia,
|
||||||
|
ItemTitle,
|
||||||
|
} from "@/components/ui/item";
|
||||||
|
import {
|
||||||
|
useConfigureChannelProvider,
|
||||||
|
useChannelConnections,
|
||||||
|
useChannelProviders,
|
||||||
|
useConnectChannelProvider,
|
||||||
|
useDisconnectChannelProvider,
|
||||||
|
} from "@/core/channels/hooks";
|
||||||
|
import {
|
||||||
|
closeConnectWindow,
|
||||||
|
openConnectUrl,
|
||||||
|
prepareConnectWindow,
|
||||||
|
} from "@/core/channels/open-connect-url";
|
||||||
|
import type { ChannelConnection, ChannelProvider } from "@/core/channels/types";
|
||||||
|
import { useI18n } from "@/core/i18n/hooks";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
import { ChannelProviderIcon } from "../channels/channel-provider-icon";
|
||||||
|
import { ChannelRuntimeConfigDialog } from "../channels/channel-runtime-config-dialog";
|
||||||
|
|
||||||
|
import { SettingsSection } from "./settings-section";
|
||||||
|
|
||||||
|
function getProviderDescription(
|
||||||
|
provider: ChannelProvider,
|
||||||
|
descriptions: Record<string, string>,
|
||||||
|
): string {
|
||||||
|
return descriptions[provider.provider] ?? provider.display_name;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getConnectionLabel(connection: ChannelConnection): string | null {
|
||||||
|
const account = connection.external_account_name;
|
||||||
|
const workspace = connection.workspace_name;
|
||||||
|
if (account && workspace) {
|
||||||
|
return `${account} · ${workspace}`;
|
||||||
|
}
|
||||||
|
return account ?? workspace ?? connection.external_account_id ?? null;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getStatusLabel(
|
||||||
|
provider: ChannelProvider,
|
||||||
|
connection: ChannelConnection | undefined,
|
||||||
|
t: ReturnType<typeof useI18n>["t"],
|
||||||
|
): string {
|
||||||
|
if (!provider.enabled) {
|
||||||
|
return t.channels.disabled;
|
||||||
|
}
|
||||||
|
if (!provider.configured) {
|
||||||
|
return t.channels.unconfigured;
|
||||||
|
}
|
||||||
|
if (provider.unavailable_reason) {
|
||||||
|
return t.channels.unavailableShort;
|
||||||
|
}
|
||||||
|
const status = connection?.status ?? provider.connection_status;
|
||||||
|
if (status === "connected") {
|
||||||
|
return t.channels.connected;
|
||||||
|
}
|
||||||
|
if (status === "pending") {
|
||||||
|
return t.channels.pending;
|
||||||
|
}
|
||||||
|
if (status === "revoked") {
|
||||||
|
return t.channels.revoked;
|
||||||
|
}
|
||||||
|
return t.channels.notConnected;
|
||||||
|
}
|
||||||
|
|
||||||
|
function getProviderUnavailableReason(
|
||||||
|
provider: ChannelProvider,
|
||||||
|
t: ReturnType<typeof useI18n>["t"],
|
||||||
|
): string | undefined {
|
||||||
|
if (provider.unavailable_reason) {
|
||||||
|
return provider.unavailable_reason;
|
||||||
|
}
|
||||||
|
if (!provider.enabled) {
|
||||||
|
return t.channels.disabled;
|
||||||
|
}
|
||||||
|
if (!provider.configured) {
|
||||||
|
return t.channels.unconfigured;
|
||||||
|
}
|
||||||
|
return provider.unavailable_reason ?? undefined;
|
||||||
|
}
|
||||||
|
|
||||||
|
function providerNeedsRuntimeConfig(provider: ChannelProvider): boolean {
|
||||||
|
return (
|
||||||
|
provider.enabled &&
|
||||||
|
!provider.configured &&
|
||||||
|
(provider.credential_fields?.length ?? 0) > 0
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
function providerCanEditRuntimeConfig(provider: ChannelProvider): boolean {
|
||||||
|
return provider.enabled && (provider.credential_fields?.length ?? 0) > 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
function ChannelProviderItem({
|
||||||
|
provider,
|
||||||
|
connection,
|
||||||
|
}: {
|
||||||
|
provider: ChannelProvider;
|
||||||
|
connection?: ChannelConnection;
|
||||||
|
}) {
|
||||||
|
const { t } = useI18n();
|
||||||
|
const connectMutation = useConnectChannelProvider();
|
||||||
|
const configureMutation = useConfigureChannelProvider();
|
||||||
|
const disconnectProviderMutation = useDisconnectChannelProvider();
|
||||||
|
const [setupOpen, setSetupOpen] = useState(false);
|
||||||
|
const isConnected =
|
||||||
|
connection?.status === "connected" ||
|
||||||
|
provider.connection_status === "connected";
|
||||||
|
const canEditRuntimeConfig = providerCanEditRuntimeConfig(provider);
|
||||||
|
const canConnect =
|
||||||
|
(provider.connectable ?? (provider.enabled && provider.configured)) &&
|
||||||
|
!isConnected;
|
||||||
|
const isConnecting =
|
||||||
|
(connectMutation.isPending &&
|
||||||
|
connectMutation.variables === provider.provider) ||
|
||||||
|
(configureMutation.isPending &&
|
||||||
|
configureMutation.variables?.provider === provider.provider);
|
||||||
|
const isDisconnecting =
|
||||||
|
disconnectProviderMutation.isPending &&
|
||||||
|
disconnectProviderMutation.variables === provider.provider;
|
||||||
|
const connectionLabel = connection ? getConnectionLabel(connection) : null;
|
||||||
|
const statusLabel = getStatusLabel(provider, connection, t);
|
||||||
|
const unavailableReason = getProviderUnavailableReason(provider, t);
|
||||||
|
|
||||||
|
const startConnect = (
|
||||||
|
connectProvider: ChannelProvider,
|
||||||
|
preparedWindow?: Window | null,
|
||||||
|
) => {
|
||||||
|
const connectWindow =
|
||||||
|
preparedWindow !== undefined
|
||||||
|
? preparedWindow
|
||||||
|
: connectProvider.auth_mode === "deep_link"
|
||||||
|
? prepareConnectWindow()
|
||||||
|
: null;
|
||||||
|
void connectMutation
|
||||||
|
.mutateAsync(connectProvider.provider)
|
||||||
|
.then((result) => {
|
||||||
|
if (result.url) {
|
||||||
|
openConnectUrl(result.url, connectWindow);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
closeConnectWindow(connectWindow);
|
||||||
|
toast.success(result.instruction);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
closeConnectWindow(connectWindow);
|
||||||
|
toast.error(
|
||||||
|
error instanceof Error ? error.message : t.channels.unavailable,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<>
|
||||||
|
<Item variant="outline" className="w-full items-start">
|
||||||
|
<ItemMedia variant="icon" className="bg-background">
|
||||||
|
<ChannelProviderIcon
|
||||||
|
provider={provider.provider}
|
||||||
|
className="size-5"
|
||||||
|
/>
|
||||||
|
</ItemMedia>
|
||||||
|
<ItemContent className="min-w-0">
|
||||||
|
<ItemTitle className="w-full">
|
||||||
|
<span className="truncate">{provider.display_name}</span>
|
||||||
|
<Badge
|
||||||
|
variant={isConnected ? "default" : "outline"}
|
||||||
|
className={cn(!isConnected && "text-muted-foreground")}
|
||||||
|
>
|
||||||
|
{isConnected ? <CheckCircle2Icon /> : <AlertCircleIcon />}
|
||||||
|
{statusLabel}
|
||||||
|
</Badge>
|
||||||
|
</ItemTitle>
|
||||||
|
<ItemDescription className="line-clamp-none">
|
||||||
|
{getProviderDescription(provider, t.channels.descriptions)}
|
||||||
|
{connectionLabel
|
||||||
|
? ` ${t.channels.connectedAs(connectionLabel)}`
|
||||||
|
: ""}
|
||||||
|
{!isConnected && provider.unavailable_reason
|
||||||
|
? ` ${provider.unavailable_reason}`
|
||||||
|
: ""}
|
||||||
|
</ItemDescription>
|
||||||
|
</ItemContent>
|
||||||
|
<ItemActions className="ml-auto">
|
||||||
|
{isConnected ? (
|
||||||
|
<>
|
||||||
|
{canEditRuntimeConfig ? (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
disabled={isConnecting || isDisconnecting}
|
||||||
|
onClick={() => setSetupOpen(true)}
|
||||||
|
>
|
||||||
|
{isConnecting ? (
|
||||||
|
<LoaderCircleIcon className="animate-spin" />
|
||||||
|
) : (
|
||||||
|
<PlugIcon />
|
||||||
|
)}
|
||||||
|
{t.channels.modify}
|
||||||
|
</Button>
|
||||||
|
) : null}
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
variant="outline"
|
||||||
|
size="sm"
|
||||||
|
disabled={isDisconnecting}
|
||||||
|
onClick={() => {
|
||||||
|
void disconnectProviderMutation
|
||||||
|
.mutateAsync(provider.provider)
|
||||||
|
.then(() => {
|
||||||
|
toast.success(t.channels.revoked);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
toast.error(
|
||||||
|
error instanceof Error
|
||||||
|
? error.message
|
||||||
|
: t.channels.unavailable,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{isDisconnecting ? (
|
||||||
|
<LoaderCircleIcon className="animate-spin" />
|
||||||
|
) : (
|
||||||
|
<UnplugIcon />
|
||||||
|
)}
|
||||||
|
{t.channels.disconnect}
|
||||||
|
</Button>
|
||||||
|
</>
|
||||||
|
) : (
|
||||||
|
<Button
|
||||||
|
type="button"
|
||||||
|
size="sm"
|
||||||
|
disabled={isConnecting}
|
||||||
|
title={unavailableReason}
|
||||||
|
onClick={() => {
|
||||||
|
if (
|
||||||
|
providerNeedsRuntimeConfig(provider) ||
|
||||||
|
canEditRuntimeConfig
|
||||||
|
) {
|
||||||
|
setSetupOpen(true);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!canConnect) {
|
||||||
|
toast.error(unavailableReason ?? t.channels.unavailable);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
startConnect(provider);
|
||||||
|
}}
|
||||||
|
>
|
||||||
|
{isConnecting ? (
|
||||||
|
<LoaderCircleIcon className="animate-spin" />
|
||||||
|
) : (
|
||||||
|
<PlugIcon />
|
||||||
|
)}
|
||||||
|
{connection?.status === "revoked"
|
||||||
|
? t.channels.reconnect
|
||||||
|
: t.channels.connect}
|
||||||
|
</Button>
|
||||||
|
)}
|
||||||
|
</ItemActions>
|
||||||
|
</Item>
|
||||||
|
<ChannelRuntimeConfigDialog
|
||||||
|
provider={provider}
|
||||||
|
open={setupOpen}
|
||||||
|
submitting={configureMutation.isPending}
|
||||||
|
onOpenChange={setSetupOpen}
|
||||||
|
onSubmit={(submitProvider, values) => {
|
||||||
|
void configureMutation
|
||||||
|
.mutateAsync({ provider: submitProvider.provider, values })
|
||||||
|
.then(() => {
|
||||||
|
setSetupOpen(false);
|
||||||
|
toast.success(t.channels.connected);
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
toast.error(
|
||||||
|
error instanceof Error ? error.message : t.channels.unavailable,
|
||||||
|
);
|
||||||
|
});
|
||||||
|
}}
|
||||||
|
/>
|
||||||
|
</>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function ChannelsSettingsPage() {
|
||||||
|
const { t } = useI18n();
|
||||||
|
const {
|
||||||
|
enabled,
|
||||||
|
providers,
|
||||||
|
isLoading: providersLoading,
|
||||||
|
error: providersError,
|
||||||
|
} = useChannelProviders();
|
||||||
|
const {
|
||||||
|
connections,
|
||||||
|
isLoading: connectionsLoading,
|
||||||
|
error: connectionsError,
|
||||||
|
} = useChannelConnections();
|
||||||
|
const isLoading = providersLoading || connectionsLoading;
|
||||||
|
const error = providersError ?? connectionsError;
|
||||||
|
const visibleProviders = providers.filter((provider) => provider.enabled);
|
||||||
|
|
||||||
|
const connectionByProvider = new Map<string, ChannelConnection>();
|
||||||
|
for (const connection of connections) {
|
||||||
|
const existing = connectionByProvider.get(connection.provider);
|
||||||
|
if (!existing || connection.status === "connected") {
|
||||||
|
connectionByProvider.set(connection.provider, connection);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<SettingsSection
|
||||||
|
title={t.settings.channels.title}
|
||||||
|
description={t.settings.channels.description}
|
||||||
|
>
|
||||||
|
{isLoading ? (
|
||||||
|
<div className="text-muted-foreground text-sm">{t.common.loading}</div>
|
||||||
|
) : error ? (
|
||||||
|
<div className="text-destructive text-sm">{t.channels.unavailable}</div>
|
||||||
|
) : !enabled ? (
|
||||||
|
<div className="text-muted-foreground text-sm">
|
||||||
|
{t.settings.channels.disabled}
|
||||||
|
</div>
|
||||||
|
) : visibleProviders.length === 0 ? (
|
||||||
|
<div className="text-muted-foreground text-sm">
|
||||||
|
{t.settings.channels.disabled}
|
||||||
|
</div>
|
||||||
|
) : (
|
||||||
|
<div className="flex w-full flex-col gap-4">
|
||||||
|
{visibleProviders.map((provider) => (
|
||||||
|
<ChannelProviderItem
|
||||||
|
key={provider.provider}
|
||||||
|
provider={provider}
|
||||||
|
connection={connectionByProvider.get(provider.provider)}
|
||||||
|
/>
|
||||||
|
))}
|
||||||
|
</div>
|
||||||
|
)}
|
||||||
|
</SettingsSection>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import {
|
import {
|
||||||
BellIcon,
|
BellIcon,
|
||||||
|
CableIcon,
|
||||||
InfoIcon,
|
InfoIcon,
|
||||||
BrainIcon,
|
BrainIcon,
|
||||||
PaletteIcon,
|
PaletteIcon,
|
||||||
@@ -21,6 +22,7 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
|||||||
import { AboutSettingsPage } from "@/components/workspace/settings/about-settings-page";
|
import { AboutSettingsPage } from "@/components/workspace/settings/about-settings-page";
|
||||||
import { AccountSettingsPage } from "@/components/workspace/settings/account-settings-page";
|
import { AccountSettingsPage } from "@/components/workspace/settings/account-settings-page";
|
||||||
import { AppearanceSettingsPage } from "@/components/workspace/settings/appearance-settings-page";
|
import { AppearanceSettingsPage } from "@/components/workspace/settings/appearance-settings-page";
|
||||||
|
import { ChannelsSettingsPage } from "@/components/workspace/settings/channels-settings-page";
|
||||||
import { MemorySettingsPage } from "@/components/workspace/settings/memory-settings-page";
|
import { MemorySettingsPage } from "@/components/workspace/settings/memory-settings-page";
|
||||||
import { NotificationSettingsPage } from "@/components/workspace/settings/notification-settings-page";
|
import { NotificationSettingsPage } from "@/components/workspace/settings/notification-settings-page";
|
||||||
import { SkillSettingsPage } from "@/components/workspace/settings/skill-settings-page";
|
import { SkillSettingsPage } from "@/components/workspace/settings/skill-settings-page";
|
||||||
@@ -31,6 +33,7 @@ import { cn } from "@/lib/utils";
|
|||||||
type SettingsSection =
|
type SettingsSection =
|
||||||
| "account"
|
| "account"
|
||||||
| "appearance"
|
| "appearance"
|
||||||
|
| "channels"
|
||||||
| "memory"
|
| "memory"
|
||||||
| "tools"
|
| "tools"
|
||||||
| "skills"
|
| "skills"
|
||||||
@@ -72,6 +75,11 @@ export function SettingsDialog(props: SettingsDialogProps) {
|
|||||||
label: t.settings.sections.notification,
|
label: t.settings.sections.notification,
|
||||||
icon: BellIcon,
|
icon: BellIcon,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
id: "channels",
|
||||||
|
label: t.settings.sections.channels,
|
||||||
|
icon: CableIcon,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
id: "memory",
|
id: "memory",
|
||||||
label: t.settings.sections.memory,
|
label: t.settings.sections.memory,
|
||||||
@@ -84,6 +92,7 @@ export function SettingsDialog(props: SettingsDialogProps) {
|
|||||||
[
|
[
|
||||||
t.settings.sections.account,
|
t.settings.sections.account,
|
||||||
t.settings.sections.appearance,
|
t.settings.sections.appearance,
|
||||||
|
t.settings.sections.channels,
|
||||||
t.settings.sections.memory,
|
t.settings.sections.memory,
|
||||||
t.settings.sections.tools,
|
t.settings.sections.tools,
|
||||||
t.settings.sections.skills,
|
t.settings.sections.skills,
|
||||||
@@ -143,6 +152,7 @@ export function SettingsDialog(props: SettingsDialogProps) {
|
|||||||
/>
|
/>
|
||||||
)}
|
)}
|
||||||
{activeSection === "notification" && <NotificationSettingsPage />}
|
{activeSection === "notification" && <NotificationSettingsPage />}
|
||||||
|
{activeSection === "channels" && <ChannelsSettingsPage />}
|
||||||
{activeSection === "about" && <AboutSettingsPage />}
|
{activeSection === "about" && <AboutSettingsPage />}
|
||||||
</div>
|
</div>
|
||||||
</ScrollArea>
|
</ScrollArea>
|
||||||
|
|||||||
@@ -0,0 +1,56 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { ChannelProviderIcon } from "@/components/workspace/channels/channel-provider-icon";
|
||||||
|
import type { ChannelThreadSource } from "@/core/threads/utils";
|
||||||
|
import { cn } from "@/lib/utils";
|
||||||
|
|
||||||
|
type ThreadChannelIconProps = {
|
||||||
|
source: ChannelThreadSource | null;
|
||||||
|
className?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ThreadChannelIcon({
|
||||||
|
source,
|
||||||
|
className,
|
||||||
|
}: ThreadChannelIconProps) {
|
||||||
|
if (!source) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
aria-label={`${source.label} channel`}
|
||||||
|
title={`${source.label} channel`}
|
||||||
|
className={cn("inline-flex shrink-0 items-center", className)}
|
||||||
|
>
|
||||||
|
<ChannelProviderIcon provider={source.provider} className="size-4" />
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
type ThreadChannelBadgeProps = {
|
||||||
|
source: ChannelThreadSource | null;
|
||||||
|
className?: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function ThreadChannelBadge({
|
||||||
|
source,
|
||||||
|
className,
|
||||||
|
}: ThreadChannelBadgeProps) {
|
||||||
|
if (!source) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<span
|
||||||
|
className={cn(
|
||||||
|
"bg-muted text-muted-foreground inline-flex h-6 max-w-32 items-center gap-1 rounded-md px-2 text-xs font-medium",
|
||||||
|
className,
|
||||||
|
)}
|
||||||
|
title={`${source.label} channel`}
|
||||||
|
>
|
||||||
|
<ChannelProviderIcon provider={source.provider} className="size-3.5" />
|
||||||
|
<span className="truncate">{source.label}</span>
|
||||||
|
</span>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -9,6 +9,7 @@ import {
|
|||||||
useSidebar,
|
useSidebar,
|
||||||
} from "@/components/ui/sidebar";
|
} from "@/components/ui/sidebar";
|
||||||
|
|
||||||
|
import { WorkspaceChannelsList } from "./channels/workspace-channels-list";
|
||||||
import { RecentChatList } from "./recent-chat-list";
|
import { RecentChatList } from "./recent-chat-list";
|
||||||
import { WorkspaceHeader } from "./workspace-header";
|
import { WorkspaceHeader } from "./workspace-header";
|
||||||
import { WorkspaceNavChatList } from "./workspace-nav-chat-list";
|
import { WorkspaceNavChatList } from "./workspace-nav-chat-list";
|
||||||
@@ -26,6 +27,7 @@ export function WorkspaceSidebar({
|
|||||||
</SidebarHeader>
|
</SidebarHeader>
|
||||||
<SidebarContent>
|
<SidebarContent>
|
||||||
<WorkspaceNavChatList />
|
<WorkspaceNavChatList />
|
||||||
|
<WorkspaceChannelsList />
|
||||||
{isSidebarOpen && <RecentChatList />}
|
{isSidebarOpen && <RecentChatList />}
|
||||||
</SidebarContent>
|
</SidebarContent>
|
||||||
<SidebarFooter>
|
<SidebarFooter>
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
import type { User } from "./types";
|
||||||
|
|
||||||
|
export const AUTH_DISABLED_USER: User = {
|
||||||
|
id: "default",
|
||||||
|
email: "default@test.local",
|
||||||
|
system_role: "admin",
|
||||||
|
needs_setup: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
const PRODUCTION_ENV_VALUES = new Set(["prod", "production"]);
|
||||||
|
|
||||||
|
function isExplicitProductionEnvironment() {
|
||||||
|
return ["DEER_FLOW_ENV", "ENVIRONMENT"].some((name) =>
|
||||||
|
PRODUCTION_ENV_VALUES.has((process.env[name] ?? "").trim().toLowerCase()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function isAuthDisabledMode() {
|
||||||
|
return (
|
||||||
|
process.env.DEER_FLOW_AUTH_DISABLED === "1" &&
|
||||||
|
!isExplicitProductionEnvironment()
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -2,6 +2,7 @@ import { cookies } from "next/headers";
|
|||||||
|
|
||||||
import { isStaticWebsiteOnly } from "../static-mode";
|
import { isStaticWebsiteOnly } from "../static-mode";
|
||||||
|
|
||||||
|
import { AUTH_DISABLED_USER, isAuthDisabledMode } from "./auth-disabled-user";
|
||||||
import { getGatewayConfig } from "./gateway-config";
|
import { getGatewayConfig } from "./gateway-config";
|
||||||
import { STATIC_WEBSITE_USER } from "./static-user";
|
import { STATIC_WEBSITE_USER } from "./static-user";
|
||||||
import { type AuthResult, userSchema } from "./types";
|
import { type AuthResult, userSchema } from "./types";
|
||||||
@@ -20,15 +21,10 @@ export async function getServerSideUser(): Promise<AuthResult> {
|
|||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
if (process.env.DEER_FLOW_AUTH_DISABLED === "1") {
|
if (isAuthDisabledMode()) {
|
||||||
return {
|
return {
|
||||||
tag: "authenticated",
|
tag: "authenticated",
|
||||||
user: {
|
user: AUTH_DISABLED_USER,
|
||||||
id: "e2e-user",
|
|
||||||
email: "e2e@test.local",
|
|
||||||
system_role: "admin",
|
|
||||||
needs_setup: false,
|
|
||||||
},
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,117 @@
|
|||||||
|
import { fetch } from "@/core/api/fetcher";
|
||||||
|
import { getBackendBaseURL } from "@/core/config";
|
||||||
|
|
||||||
|
import type {
|
||||||
|
ChannelConnectResponse,
|
||||||
|
ChannelConnection,
|
||||||
|
ChannelConnectionsResponse,
|
||||||
|
ChannelProviderId,
|
||||||
|
ChannelProvider,
|
||||||
|
ChannelProvidersResponse,
|
||||||
|
ChannelRuntimeConfigValues,
|
||||||
|
} from "./types";
|
||||||
|
|
||||||
|
function channelsUrl(path: string): string {
|
||||||
|
return `${getBackendBaseURL()}/api/channels${path}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
async function throwChannelApiError(
|
||||||
|
response: Response,
|
||||||
|
fallback: string,
|
||||||
|
): Promise<never> {
|
||||||
|
const body = (await response.json().catch(() => ({}))) as {
|
||||||
|
detail?: unknown;
|
||||||
|
};
|
||||||
|
throw new Error(typeof body.detail === "string" ? body.detail : fallback);
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function listChannelProviders(): Promise<ChannelProvidersResponse> {
|
||||||
|
const response = await fetch(channelsUrl("/providers"));
|
||||||
|
if (!response.ok) {
|
||||||
|
await throwChannelApiError(
|
||||||
|
response,
|
||||||
|
`Failed to load channel providers: ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return response.json() as Promise<ChannelProvidersResponse>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function listChannelConnections(): Promise<ChannelConnection[]> {
|
||||||
|
const response = await fetch(channelsUrl("/connections"));
|
||||||
|
if (!response.ok) {
|
||||||
|
await throwChannelApiError(
|
||||||
|
response,
|
||||||
|
`Failed to load channel connections: ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
const data = (await response.json()) as ChannelConnectionsResponse;
|
||||||
|
return data.connections;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function connectChannelProvider(
|
||||||
|
provider: ChannelProviderId,
|
||||||
|
): Promise<ChannelConnectResponse> {
|
||||||
|
const response = await fetch(
|
||||||
|
channelsUrl(`/${encodeURIComponent(provider)}/connect`),
|
||||||
|
{ method: "POST" },
|
||||||
|
);
|
||||||
|
if (!response.ok) {
|
||||||
|
await throwChannelApiError(
|
||||||
|
response,
|
||||||
|
`Failed to connect ${provider}: ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return response.json() as Promise<ChannelConnectResponse>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function configureChannelProvider(
|
||||||
|
provider: ChannelProviderId,
|
||||||
|
values: ChannelRuntimeConfigValues,
|
||||||
|
): Promise<ChannelProvider> {
|
||||||
|
const response = await fetch(
|
||||||
|
channelsUrl(`/${encodeURIComponent(provider)}/runtime-config`),
|
||||||
|
{
|
||||||
|
method: "POST",
|
||||||
|
headers: { "Content-Type": "application/json" },
|
||||||
|
body: JSON.stringify({ values }),
|
||||||
|
},
|
||||||
|
);
|
||||||
|
if (!response.ok) {
|
||||||
|
await throwChannelApiError(
|
||||||
|
response,
|
||||||
|
`Failed to configure ${provider}: ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return response.json() as Promise<ChannelProvider>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function disconnectChannelConnection(
|
||||||
|
connectionId: string,
|
||||||
|
): Promise<void> {
|
||||||
|
const response = await fetch(
|
||||||
|
channelsUrl(`/connections/${encodeURIComponent(connectionId)}`),
|
||||||
|
{ method: "DELETE" },
|
||||||
|
);
|
||||||
|
if (!response.ok) {
|
||||||
|
await throwChannelApiError(
|
||||||
|
response,
|
||||||
|
`Failed to disconnect channel: ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export async function disconnectChannelProvider(
|
||||||
|
provider: ChannelProviderId,
|
||||||
|
): Promise<ChannelProvider> {
|
||||||
|
const response = await fetch(
|
||||||
|
channelsUrl(`/${encodeURIComponent(provider)}/runtime-config`),
|
||||||
|
{ method: "DELETE" },
|
||||||
|
);
|
||||||
|
if (!response.ok) {
|
||||||
|
await throwChannelApiError(
|
||||||
|
response,
|
||||||
|
`Failed to disconnect ${provider}: ${response.statusText}`,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return response.json() as Promise<ChannelProvider>;
|
||||||
|
}
|
||||||
@@ -0,0 +1,96 @@
|
|||||||
|
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
||||||
|
|
||||||
|
import {
|
||||||
|
configureChannelProvider,
|
||||||
|
connectChannelProvider,
|
||||||
|
disconnectChannelConnection,
|
||||||
|
disconnectChannelProvider,
|
||||||
|
listChannelConnections,
|
||||||
|
listChannelProviders,
|
||||||
|
} from "./api";
|
||||||
|
import type { ChannelProviderId, ChannelRuntimeConfigValues } from "./types";
|
||||||
|
|
||||||
|
export const channelProviderQueryKey = ["channelProviders"] as const;
|
||||||
|
export const channelConnectionsQueryKey = ["channelConnections"] as const;
|
||||||
|
|
||||||
|
export function useChannelProviders() {
|
||||||
|
const { data, isLoading, error } = useQuery({
|
||||||
|
queryKey: channelProviderQueryKey,
|
||||||
|
queryFn: () => listChannelProviders(),
|
||||||
|
});
|
||||||
|
return {
|
||||||
|
enabled: data?.enabled ?? false,
|
||||||
|
providers: data?.providers ?? [],
|
||||||
|
isLoading,
|
||||||
|
error,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useChannelConnections() {
|
||||||
|
const { data, isLoading, error } = useQuery({
|
||||||
|
queryKey: channelConnectionsQueryKey,
|
||||||
|
queryFn: () => listChannelConnections(),
|
||||||
|
});
|
||||||
|
return { connections: data ?? [], isLoading, error };
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useConnectChannelProvider() {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (provider: ChannelProviderId) =>
|
||||||
|
connectChannelProvider(provider),
|
||||||
|
onSuccess: () => {
|
||||||
|
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: channelConnectionsQueryKey,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useConfigureChannelProvider() {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: ({
|
||||||
|
provider,
|
||||||
|
values,
|
||||||
|
}: {
|
||||||
|
provider: ChannelProviderId;
|
||||||
|
values: ChannelRuntimeConfigValues;
|
||||||
|
}) => configureChannelProvider(provider, values),
|
||||||
|
onSuccess: () => {
|
||||||
|
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: channelConnectionsQueryKey,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useDisconnectChannelConnection() {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (connectionId: string) =>
|
||||||
|
disconnectChannelConnection(connectionId),
|
||||||
|
onSuccess: () => {
|
||||||
|
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: channelConnectionsQueryKey,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useDisconnectChannelProvider() {
|
||||||
|
const queryClient = useQueryClient();
|
||||||
|
return useMutation({
|
||||||
|
mutationFn: (provider: ChannelProviderId) =>
|
||||||
|
disconnectChannelProvider(provider),
|
||||||
|
onSuccess: () => {
|
||||||
|
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: channelConnectionsQueryKey,
|
||||||
|
});
|
||||||
|
},
|
||||||
|
});
|
||||||
|
}
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
export type ChannelConnectWindow = Window | null;
|
||||||
|
|
||||||
|
export function prepareConnectWindow(): ChannelConnectWindow {
|
||||||
|
const opened = window.open("about:blank", "_blank");
|
||||||
|
if (opened) {
|
||||||
|
opened.opener = null;
|
||||||
|
}
|
||||||
|
return opened;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function openConnectUrl(
|
||||||
|
url: string,
|
||||||
|
connectWindow: ChannelConnectWindow = prepareConnectWindow(),
|
||||||
|
) {
|
||||||
|
if (connectWindow && !connectWindow.closed) {
|
||||||
|
connectWindow.location.replace(url);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
window.location.assign(url);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function closeConnectWindow(connectWindow: ChannelConnectWindow) {
|
||||||
|
if (connectWindow && !connectWindow.closed) {
|
||||||
|
connectWindow.close();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,53 @@
|
|||||||
|
export type ChannelProviderId = "telegram" | "slack" | "discord" | string;
|
||||||
|
|
||||||
|
export interface ChannelCredentialField {
|
||||||
|
name: string;
|
||||||
|
label: string;
|
||||||
|
type: string;
|
||||||
|
required: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ChannelRuntimeConfigValues = Record<string, string>;
|
||||||
|
|
||||||
|
export interface ChannelProvider {
|
||||||
|
provider: ChannelProviderId;
|
||||||
|
display_name: string;
|
||||||
|
enabled: boolean;
|
||||||
|
configured: boolean;
|
||||||
|
connectable?: boolean;
|
||||||
|
unavailable_reason?: string | null;
|
||||||
|
auth_mode: string;
|
||||||
|
connection_status: string;
|
||||||
|
credential_fields: ChannelCredentialField[];
|
||||||
|
credential_values?: ChannelRuntimeConfigValues;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChannelProvidersResponse {
|
||||||
|
enabled: boolean;
|
||||||
|
providers: ChannelProvider[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChannelConnection {
|
||||||
|
id: string;
|
||||||
|
provider: ChannelProviderId;
|
||||||
|
status: string;
|
||||||
|
external_account_id?: string | null;
|
||||||
|
external_account_name?: string | null;
|
||||||
|
workspace_id?: string | null;
|
||||||
|
workspace_name?: string | null;
|
||||||
|
scopes: string[];
|
||||||
|
metadata: Record<string, unknown>;
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChannelConnectionsResponse {
|
||||||
|
connections: ChannelConnection[];
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface ChannelConnectResponse {
|
||||||
|
provider: ChannelProviderId;
|
||||||
|
mode: string;
|
||||||
|
url?: string | null;
|
||||||
|
code: string;
|
||||||
|
instruction: string;
|
||||||
|
expires_in: number;
|
||||||
|
}
|
||||||
@@ -170,6 +170,7 @@ export const enUS: Translations = {
|
|||||||
sidebar: {
|
sidebar: {
|
||||||
newChat: "New chat",
|
newChat: "New chat",
|
||||||
chats: "Chats",
|
chats: "Chats",
|
||||||
|
channels: "Channels",
|
||||||
recentChats: "Recent chats",
|
recentChats: "Recent chats",
|
||||||
demoChats: "Demo chats",
|
demoChats: "Demo chats",
|
||||||
agents: "Agents",
|
agents: "Agents",
|
||||||
@@ -252,6 +253,42 @@ export const enUS: Translations = {
|
|||||||
// Chats
|
// Chats
|
||||||
chats: {
|
chats: {
|
||||||
searchChats: "Search chats",
|
searchChats: "Search chats",
|
||||||
|
loadMoreToSearch: "Load more to search older conversations",
|
||||||
|
loadingMore: "Loading more...",
|
||||||
|
loadOlderChats: "Load older chats",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Channels
|
||||||
|
channels: {
|
||||||
|
title: "Channels",
|
||||||
|
connect: "Connect",
|
||||||
|
modify: "Modify",
|
||||||
|
reconnect: "Reconnect",
|
||||||
|
disconnect: "Disconnect",
|
||||||
|
connected: "Connected",
|
||||||
|
notConnected: "Not connected",
|
||||||
|
pending: "Pending",
|
||||||
|
revoked: "Disconnected",
|
||||||
|
disabled: "Disabled",
|
||||||
|
unconfigured: "Not configured",
|
||||||
|
unavailable: "Channel connections are unavailable right now.",
|
||||||
|
unavailableShort: "Unavailable",
|
||||||
|
setupTitle: (name: string) => `Connect ${name}`,
|
||||||
|
setupEditTitle: (name: string) => `Modify ${name}`,
|
||||||
|
setupDescription:
|
||||||
|
"Enter the values needed by this server process. They are not written to config.yaml.",
|
||||||
|
saveAndConnect: "Save and connect",
|
||||||
|
saveChanges: "Save changes",
|
||||||
|
descriptions: {
|
||||||
|
telegram: "Telegram direct messages through your DeerFlow bot.",
|
||||||
|
slack: "Slack workspace messages and mentions.",
|
||||||
|
discord: "Discord server messages through your DeerFlow bot.",
|
||||||
|
feishu: "Feishu and Lark messages through your DeerFlow app.",
|
||||||
|
dingtalk: "DingTalk Stream Push messages through your DeerFlow bot.",
|
||||||
|
wechat: "WeChat iLink messages through your DeerFlow bot.",
|
||||||
|
wecom: "WeCom messages through your DeerFlow AI bot.",
|
||||||
|
},
|
||||||
|
connectedAs: (name: string) => `Connected as ${name}.`,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Page titles (document title)
|
// Page titles (document title)
|
||||||
@@ -354,6 +391,7 @@ export const enUS: Translations = {
|
|||||||
sections: {
|
sections: {
|
||||||
account: "Account",
|
account: "Account",
|
||||||
appearance: "Appearance",
|
appearance: "Appearance",
|
||||||
|
channels: "Channels",
|
||||||
memory: "Memory",
|
memory: "Memory",
|
||||||
tools: "Tools",
|
tools: "Tools",
|
||||||
skills: "Skills",
|
skills: "Skills",
|
||||||
@@ -456,6 +494,13 @@ export const enUS: Translations = {
|
|||||||
title: "Tools",
|
title: "Tools",
|
||||||
description: "Manage the configuration and enabled status of MCP tools.",
|
description: "Manage the configuration and enabled status of MCP tools.",
|
||||||
},
|
},
|
||||||
|
channels: {
|
||||||
|
title: "Channels",
|
||||||
|
description:
|
||||||
|
"Connect IM accounts that can send messages to DeerFlow from outside the browser.",
|
||||||
|
disabled:
|
||||||
|
"Channel connections are not enabled on this server. Ask an administrator to enable channel_connections.",
|
||||||
|
},
|
||||||
skills: {
|
skills: {
|
||||||
title: "Agent Skills",
|
title: "Agent Skills",
|
||||||
description:
|
description:
|
||||||
|
|||||||
@@ -117,6 +117,7 @@ export interface Translations {
|
|||||||
chats: string;
|
chats: string;
|
||||||
demoChats: string;
|
demoChats: string;
|
||||||
agents: string;
|
agents: string;
|
||||||
|
channels: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Agents
|
// Agents
|
||||||
@@ -183,6 +184,33 @@ export interface Translations {
|
|||||||
// Chats
|
// Chats
|
||||||
chats: {
|
chats: {
|
||||||
searchChats: string;
|
searchChats: string;
|
||||||
|
loadMoreToSearch: string;
|
||||||
|
loadingMore: string;
|
||||||
|
loadOlderChats: string;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Channels
|
||||||
|
channels: {
|
||||||
|
title: string;
|
||||||
|
connect: string;
|
||||||
|
modify: string;
|
||||||
|
reconnect: string;
|
||||||
|
disconnect: string;
|
||||||
|
connected: string;
|
||||||
|
notConnected: string;
|
||||||
|
pending: string;
|
||||||
|
revoked: string;
|
||||||
|
disabled: string;
|
||||||
|
unconfigured: string;
|
||||||
|
unavailable: string;
|
||||||
|
unavailableShort: string;
|
||||||
|
setupTitle: (name: string) => string;
|
||||||
|
setupEditTitle: (name: string) => string;
|
||||||
|
setupDescription: string;
|
||||||
|
saveAndConnect: string;
|
||||||
|
saveChanges: string;
|
||||||
|
descriptions: Record<string, string>;
|
||||||
|
connectedAs: (name: string) => string;
|
||||||
};
|
};
|
||||||
|
|
||||||
// Page titles (document title)
|
// Page titles (document title)
|
||||||
@@ -281,6 +309,7 @@ export interface Translations {
|
|||||||
sections: {
|
sections: {
|
||||||
account: string;
|
account: string;
|
||||||
appearance: string;
|
appearance: string;
|
||||||
|
channels: string;
|
||||||
memory: string;
|
memory: string;
|
||||||
tools: string;
|
tools: string;
|
||||||
skills: string;
|
skills: string;
|
||||||
@@ -376,6 +405,11 @@ export interface Translations {
|
|||||||
title: string;
|
title: string;
|
||||||
description: string;
|
description: string;
|
||||||
};
|
};
|
||||||
|
channels: {
|
||||||
|
title: string;
|
||||||
|
description: string;
|
||||||
|
disabled: string;
|
||||||
|
};
|
||||||
skills: {
|
skills: {
|
||||||
title: string;
|
title: string;
|
||||||
description: string;
|
description: string;
|
||||||
|
|||||||
@@ -164,6 +164,7 @@ export const zhCN: Translations = {
|
|||||||
sidebar: {
|
sidebar: {
|
||||||
newChat: "新对话",
|
newChat: "新对话",
|
||||||
chats: "对话",
|
chats: "对话",
|
||||||
|
channels: "渠道",
|
||||||
recentChats: "最近的对话",
|
recentChats: "最近的对话",
|
||||||
demoChats: "演示对话",
|
demoChats: "演示对话",
|
||||||
agents: "智能体",
|
agents: "智能体",
|
||||||
@@ -240,6 +241,42 @@ export const zhCN: Translations = {
|
|||||||
// Chats
|
// Chats
|
||||||
chats: {
|
chats: {
|
||||||
searchChats: "搜索对话",
|
searchChats: "搜索对话",
|
||||||
|
loadMoreToSearch: "加载更多以搜索更早的对话",
|
||||||
|
loadingMore: "正在加载...",
|
||||||
|
loadOlderChats: "加载更早的对话",
|
||||||
|
},
|
||||||
|
|
||||||
|
// Channels
|
||||||
|
channels: {
|
||||||
|
title: "渠道",
|
||||||
|
connect: "连接",
|
||||||
|
modify: "修改",
|
||||||
|
reconnect: "重新连接",
|
||||||
|
disconnect: "断开连接",
|
||||||
|
connected: "已连接",
|
||||||
|
notConnected: "未连接",
|
||||||
|
pending: "待完成",
|
||||||
|
revoked: "已断开",
|
||||||
|
disabled: "已停用",
|
||||||
|
unconfigured: "未配置",
|
||||||
|
unavailable: "当前无法使用渠道连接。",
|
||||||
|
unavailableShort: "不可用",
|
||||||
|
setupTitle: (name: string) => `连接 ${name}`,
|
||||||
|
setupEditTitle: (name: string) => `修改 ${name}`,
|
||||||
|
setupDescription:
|
||||||
|
"填写当前服务进程需要的配置值。这些内容不会写入 config.yaml。",
|
||||||
|
saveAndConnect: "保存并连接",
|
||||||
|
saveChanges: "保存修改",
|
||||||
|
descriptions: {
|
||||||
|
telegram: "通过 DeerFlow Bot 接收 Telegram 私聊消息。",
|
||||||
|
slack: "接收 Slack 工作区消息和提及。",
|
||||||
|
discord: "通过 DeerFlow Bot 接收 Discord 服务器消息。",
|
||||||
|
feishu: "通过 DeerFlow 应用接收飞书和 Lark 消息。",
|
||||||
|
dingtalk: "通过 DeerFlow Bot 接收钉钉 Stream Push 消息。",
|
||||||
|
wechat: "通过 DeerFlow Bot 接收微信 iLink 消息。",
|
||||||
|
wecom: "通过 DeerFlow AI Bot 接收企业微信消息。",
|
||||||
|
},
|
||||||
|
connectedAs: (name: string) => `已连接为 ${name}。`,
|
||||||
},
|
},
|
||||||
|
|
||||||
// Page titles (document title)
|
// Page titles (document title)
|
||||||
@@ -338,6 +375,7 @@ export const zhCN: Translations = {
|
|||||||
sections: {
|
sections: {
|
||||||
account: "账号",
|
account: "账号",
|
||||||
appearance: "外观",
|
appearance: "外观",
|
||||||
|
channels: "渠道",
|
||||||
memory: "记忆",
|
memory: "记忆",
|
||||||
tools: "工具",
|
tools: "工具",
|
||||||
skills: "技能",
|
skills: "技能",
|
||||||
@@ -437,6 +475,12 @@ export const zhCN: Translations = {
|
|||||||
title: "工具",
|
title: "工具",
|
||||||
description: "管理 MCP 工具的配置和启用状态。",
|
description: "管理 MCP 工具的配置和启用状态。",
|
||||||
},
|
},
|
||||||
|
channels: {
|
||||||
|
title: "渠道",
|
||||||
|
description: "连接可在浏览器外向 DeerFlow 发送消息的即时通讯账号。",
|
||||||
|
disabled:
|
||||||
|
"当前服务器未启用渠道连接。请联系管理员开启 channel_connections。",
|
||||||
|
},
|
||||||
skills: {
|
skills: {
|
||||||
title: "技能",
|
title: "技能",
|
||||||
description: "管理 Agent Skill 配置和启用状态。",
|
description: "管理 Agent Skill 配置和启用状态。",
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
|||||||
import { useStream } from "@langchain/langgraph-sdk/react";
|
import { useStream } from "@langchain/langgraph-sdk/react";
|
||||||
import {
|
import {
|
||||||
type QueryClient,
|
type QueryClient,
|
||||||
|
type InfiniteData,
|
||||||
|
useInfiniteQuery,
|
||||||
useMutation,
|
useMutation,
|
||||||
useQuery,
|
useQuery,
|
||||||
useQueryClient,
|
useQueryClient,
|
||||||
@@ -24,6 +26,11 @@ import type { UploadedFileInfo } from "../uploads";
|
|||||||
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
import { promptInputFilePartToFile, uploadFiles } from "../uploads";
|
||||||
|
|
||||||
import { fetchThreadTokenUsage } from "./api";
|
import { fetchThreadTokenUsage } from "./api";
|
||||||
|
import {
|
||||||
|
buildThreadsSearchQueryOptions,
|
||||||
|
DEFAULT_THREAD_SEARCH_PARAMS,
|
||||||
|
type ThreadSearchParams,
|
||||||
|
} from "./thread-search-query";
|
||||||
import { threadTokenUsageQueryKey } from "./token-usage";
|
import { threadTokenUsageQueryKey } from "./token-usage";
|
||||||
import type {
|
import type {
|
||||||
AgentThread,
|
AgentThread,
|
||||||
@@ -311,6 +318,56 @@ export function upsertThreadInSearchCache(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function upsertThreadInInfiniteCache(
|
||||||
|
queryClient: QueryClient,
|
||||||
|
thread: AgentThread,
|
||||||
|
) {
|
||||||
|
queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) => {
|
||||||
|
if (!oldData) {
|
||||||
|
return oldData;
|
||||||
|
}
|
||||||
|
|
||||||
|
const merged = oldData.pages.map((page) =>
|
||||||
|
page.map((t) =>
|
||||||
|
t.thread_id === thread.thread_id
|
||||||
|
? {
|
||||||
|
...thread,
|
||||||
|
...t,
|
||||||
|
metadata: {
|
||||||
|
...(thread.metadata ?? {}),
|
||||||
|
...(t.metadata ?? {}),
|
||||||
|
},
|
||||||
|
values: {
|
||||||
|
...thread.values,
|
||||||
|
...t.values,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: t,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
const exists = merged.some((page) =>
|
||||||
|
page.some((t) => t.thread_id === thread.thread_id),
|
||||||
|
);
|
||||||
|
if (exists) {
|
||||||
|
return { ...oldData, pages: merged };
|
||||||
|
}
|
||||||
|
|
||||||
|
const firstPage = merged[0] ?? [];
|
||||||
|
const restPages = merged.slice(1);
|
||||||
|
return {
|
||||||
|
...oldData,
|
||||||
|
pages: [[thread, ...firstPage], ...restPages],
|
||||||
|
};
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
function getStreamErrorMessage(error: unknown): string {
|
function getStreamErrorMessage(error: unknown): string {
|
||||||
if (typeof error === "string" && error.trim()) {
|
if (typeof error === "string" && error.trim()) {
|
||||||
return error;
|
return error;
|
||||||
@@ -364,7 +421,7 @@ export function useThreadStream({
|
|||||||
loadMore: loadMoreHistory,
|
loadMore: loadMoreHistory,
|
||||||
loading: isHistoryLoading,
|
loading: isHistoryLoading,
|
||||||
appendMessages,
|
appendMessages,
|
||||||
} = useThreadHistory(onStreamThreadId ?? "");
|
} = useThreadHistory(onStreamThreadId ?? "", { enabled: !isMock });
|
||||||
|
|
||||||
// Keep listeners ref updated with latest callbacks
|
// Keep listeners ref updated with latest callbacks
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
@@ -417,6 +474,19 @@ export function useThreadStream({
|
|||||||
},
|
},
|
||||||
interrupts: {},
|
interrupts: {},
|
||||||
});
|
});
|
||||||
|
upsertThreadInInfiniteCache(queryClient, {
|
||||||
|
thread_id: meta.thread_id,
|
||||||
|
created_at: now,
|
||||||
|
updated_at: now,
|
||||||
|
metadata: context.agent_name ? { agent_name: context.agent_name } : {},
|
||||||
|
status: "busy",
|
||||||
|
values: {
|
||||||
|
title: t.pages.newChat,
|
||||||
|
messages: [],
|
||||||
|
artifacts: [],
|
||||||
|
},
|
||||||
|
interrupts: {},
|
||||||
|
});
|
||||||
if (context.agent_name && !isMock) {
|
if (context.agent_name && !isMock) {
|
||||||
void getAPIClient()
|
void getAPIClient()
|
||||||
.threads.update(meta.thread_id, {
|
.threads.update(meta.thread_id, {
|
||||||
@@ -488,6 +558,27 @@ export function useThreadStream({
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
const nextTitle: string = update.title;
|
||||||
|
void queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||||
|
mapInfiniteThreadsCache(
|
||||||
|
oldData,
|
||||||
|
(t): AgentThread =>
|
||||||
|
t.thread_id === threadIdRef.current
|
||||||
|
? {
|
||||||
|
...t,
|
||||||
|
values: {
|
||||||
|
...t.values,
|
||||||
|
title: nextTitle,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: t,
|
||||||
|
),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -542,6 +633,9 @@ export function useThreadStream({
|
|||||||
.filter((id): id is string => Boolean(id)),
|
.filter((id): id is string => Boolean(id)),
|
||||||
);
|
);
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
});
|
||||||
if (threadIdRef.current && !isMock) {
|
if (threadIdRef.current && !isMock) {
|
||||||
void queryClient.invalidateQueries({
|
void queryClient.invalidateQueries({
|
||||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||||
@@ -801,6 +895,9 @@ export function useThreadStream({
|
|||||||
},
|
},
|
||||||
);
|
);
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
});
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
setOptimisticMessages([]);
|
setOptimisticMessages([]);
|
||||||
setIsUploading(false);
|
setIsUploading(false);
|
||||||
@@ -854,8 +951,15 @@ export function useThreadStream({
|
|||||||
} as const;
|
} as const;
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useThreadHistory(threadId: string) {
|
type ThreadHistoryOptions = {
|
||||||
const runs = useThreadRuns(threadId);
|
enabled?: boolean;
|
||||||
|
};
|
||||||
|
|
||||||
|
export function useThreadHistory(
|
||||||
|
threadId: string,
|
||||||
|
{ enabled = true }: ThreadHistoryOptions = {},
|
||||||
|
) {
|
||||||
|
const runs = useThreadRuns(threadId, { enabled });
|
||||||
const threadIdRef = useRef(threadId);
|
const threadIdRef = useRef(threadId);
|
||||||
const runsRef = useRef(runs.data ?? []);
|
const runsRef = useRef(runs.data ?? []);
|
||||||
const indexRef = useRef(-1);
|
const indexRef = useRef(-1);
|
||||||
@@ -864,10 +968,15 @@ export function useThreadHistory(threadId: string) {
|
|||||||
const loadingRunIdRef = useRef<string | null>(null);
|
const loadingRunIdRef = useRef<string | null>(null);
|
||||||
const loadedRunIdsRef = useRef<Set<string>>(new Set());
|
const loadedRunIdsRef = useRef<Set<string>>(new Set());
|
||||||
const runBeforeSeqRef = useRef<Map<string, number>>(new Map());
|
const runBeforeSeqRef = useRef<Map<string, number>>(new Map());
|
||||||
|
const loadGenerationRef = useRef(0);
|
||||||
const [loading, setLoading] = useState(false);
|
const [loading, setLoading] = useState(false);
|
||||||
const [messages, setMessages] = useState<Message[]>([]);
|
const [messages, setMessages] = useState<Message[]>([]);
|
||||||
|
|
||||||
const loadMessages = useCallback(async () => {
|
const loadMessages = useCallback(async () => {
|
||||||
|
if (!enabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
const loadGeneration = loadGenerationRef.current;
|
||||||
if (loadingRef.current) {
|
if (loadingRef.current) {
|
||||||
const pendingRunIndex = findLatestUnloadedRunIndex(
|
const pendingRunIndex = findLatestUnloadedRunIndex(
|
||||||
runsRef.current,
|
runsRef.current,
|
||||||
@@ -921,12 +1030,15 @@ export function useThreadHistory(threadId: string) {
|
|||||||
}).then((res) => {
|
}).then((res) => {
|
||||||
return res.json();
|
return res.json();
|
||||||
});
|
});
|
||||||
|
if (
|
||||||
|
loadGenerationRef.current !== loadGeneration ||
|
||||||
|
threadIdRef.current !== requestThreadId
|
||||||
|
) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
const _messages = result.data
|
const _messages = result.data
|
||||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||||
.map((m) => m.content);
|
.map((m) => m.content);
|
||||||
if (threadIdRef.current !== requestThreadId) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
setMessages((prev) =>
|
setMessages((prev) =>
|
||||||
dedupeMessagesByIdentity([..._messages, ...prev]),
|
dedupeMessagesByIdentity([..._messages, ...prev]),
|
||||||
);
|
);
|
||||||
@@ -961,16 +1073,19 @@ export function useThreadHistory(threadId: string) {
|
|||||||
} catch (err) {
|
} catch (err) {
|
||||||
console.error(err);
|
console.error(err);
|
||||||
} finally {
|
} finally {
|
||||||
loadingRef.current = false;
|
if (loadGenerationRef.current === loadGeneration) {
|
||||||
loadingRunIdRef.current = null;
|
loadingRef.current = false;
|
||||||
setLoading(false);
|
loadingRunIdRef.current = null;
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}, []);
|
}, [enabled]);
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const threadChanged = threadIdRef.current !== threadId;
|
const threadChanged = threadIdRef.current !== threadId;
|
||||||
threadIdRef.current = threadId;
|
threadIdRef.current = threadId;
|
||||||
|
|
||||||
if (threadChanged) {
|
if (!enabled || threadChanged) {
|
||||||
|
loadGenerationRef.current += 1;
|
||||||
runsRef.current = [];
|
runsRef.current = [];
|
||||||
indexRef.current = -1;
|
indexRef.current = -1;
|
||||||
pendingLoadRef.current = false;
|
pendingLoadRef.current = false;
|
||||||
@@ -982,6 +1097,10 @@ export function useThreadHistory(threadId: string) {
|
|||||||
setMessages([]);
|
setMessages([]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (!enabled) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
if (runs.data && runs.data.length > 0) {
|
if (runs.data && runs.data.length > 0) {
|
||||||
runsRef.current = runs.data ?? [];
|
runsRef.current = runs.data ?? [];
|
||||||
indexRef.current = findLatestUnloadedRunIndex(
|
indexRef.current = findLatestUnloadedRunIndex(
|
||||||
@@ -992,14 +1111,15 @@ export function useThreadHistory(threadId: string) {
|
|||||||
loadMessages().catch(() => {
|
loadMessages().catch(() => {
|
||||||
toast.error("Failed to load thread history.");
|
toast.error("Failed to load thread history.");
|
||||||
});
|
});
|
||||||
}, [threadId, runs.data, loadMessages]);
|
}, [enabled, threadId, runs.data, loadMessages]);
|
||||||
|
|
||||||
const appendMessages = useCallback((_messages: Message[]) => {
|
const appendMessages = useCallback((_messages: Message[]) => {
|
||||||
setMessages((prev) => {
|
setMessages((prev) => {
|
||||||
return dedupeMessagesByIdentity([...prev, ..._messages]);
|
return dedupeMessagesByIdentity([...prev, ..._messages]);
|
||||||
});
|
});
|
||||||
}, []);
|
}, []);
|
||||||
const hasMore = indexRef.current >= 0 || !runs.data;
|
const hasMore =
|
||||||
|
enabled && Boolean(threadId) && (indexRef.current >= 0 || !runs.data);
|
||||||
return {
|
return {
|
||||||
runs: runs.data,
|
runs: runs.data,
|
||||||
messages,
|
messages,
|
||||||
@@ -1011,73 +1131,98 @@ export function useThreadHistory(threadId: string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export function useThreads(
|
export function useThreads(
|
||||||
params: Parameters<ThreadsClient["search"]>[0] = {
|
params: ThreadSearchParams = DEFAULT_THREAD_SEARCH_PARAMS,
|
||||||
limit: 50,
|
) {
|
||||||
|
const apiClient = getAPIClient();
|
||||||
|
return useQuery<AgentThread[]>({
|
||||||
|
...buildThreadsSearchQueryOptions(apiClient, params),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export const INFINITE_THREADS_PAGE_SIZE = 50;
|
||||||
|
|
||||||
|
export const INFINITE_THREADS_QUERY_KEY_PREFIX = [
|
||||||
|
"threads",
|
||||||
|
"searchInfinite",
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
type InfiniteThreadsParams = Omit<
|
||||||
|
Parameters<ThreadsClient["search"]>[0],
|
||||||
|
"limit" | "offset"
|
||||||
|
>;
|
||||||
|
|
||||||
|
export function getInfiniteThreadsNextPageParam(
|
||||||
|
lastPage: AgentThread[],
|
||||||
|
allPages: AgentThread[][],
|
||||||
|
pageSize: number = INFINITE_THREADS_PAGE_SIZE,
|
||||||
|
): number | undefined {
|
||||||
|
if (lastPage.length < pageSize) {
|
||||||
|
return undefined;
|
||||||
|
}
|
||||||
|
return allPages.reduce((sum, page) => sum + page.length, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function mapInfiniteThreadsCache(
|
||||||
|
oldData: InfiniteData<AgentThread[]> | undefined,
|
||||||
|
mapper: (thread: AgentThread) => AgentThread,
|
||||||
|
): InfiniteData<AgentThread[]> | undefined {
|
||||||
|
if (!oldData) {
|
||||||
|
return oldData;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...oldData,
|
||||||
|
pages: oldData.pages.map((page) => page.map(mapper)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function filterInfiniteThreadsCache(
|
||||||
|
oldData: InfiniteData<AgentThread[]> | undefined,
|
||||||
|
predicate: (thread: AgentThread) => boolean,
|
||||||
|
): InfiniteData<AgentThread[]> | undefined {
|
||||||
|
if (!oldData) {
|
||||||
|
return oldData;
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
...oldData,
|
||||||
|
pages: oldData.pages.map((page) => page.filter(predicate)),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
export function useInfiniteThreads(
|
||||||
|
params: InfiniteThreadsParams = {
|
||||||
sortBy: "updated_at",
|
sortBy: "updated_at",
|
||||||
sortOrder: "desc",
|
sortOrder: "desc",
|
||||||
select: ["thread_id", "updated_at", "values", "metadata"],
|
select: ["thread_id", "updated_at", "values", "metadata"],
|
||||||
},
|
},
|
||||||
) {
|
) {
|
||||||
const apiClient = getAPIClient();
|
const apiClient = getAPIClient();
|
||||||
return useQuery<AgentThread[]>({
|
return useInfiniteQuery<
|
||||||
queryKey: ["threads", "search", params],
|
AgentThread[],
|
||||||
queryFn: async () => {
|
Error,
|
||||||
const maxResults = params.limit;
|
InfiniteData<AgentThread[]>,
|
||||||
const initialOffset = params.offset ?? 0;
|
readonly unknown[],
|
||||||
const DEFAULT_PAGE_SIZE = 50;
|
number
|
||||||
|
>({
|
||||||
// Preserve prior semantics: if a non-positive limit is explicitly provided,
|
queryKey: [...INFINITE_THREADS_QUERY_KEY_PREFIX, params],
|
||||||
// delegate to a single search call with the original parameters.
|
initialPageParam: 0,
|
||||||
if (maxResults !== undefined && maxResults <= 0) {
|
queryFn: async ({ pageParam }) => {
|
||||||
const response =
|
const response = (await apiClient.threads.search<AgentThreadState>({
|
||||||
await apiClient.threads.search<AgentThreadState>(params);
|
...params,
|
||||||
return response as AgentThread[];
|
limit: INFINITE_THREADS_PAGE_SIZE,
|
||||||
}
|
offset: pageParam,
|
||||||
|
})) as AgentThread[];
|
||||||
const pageSize =
|
return response;
|
||||||
typeof maxResults === "number" && maxResults > 0
|
|
||||||
? Math.min(DEFAULT_PAGE_SIZE, maxResults)
|
|
||||||
: DEFAULT_PAGE_SIZE;
|
|
||||||
|
|
||||||
const threads: AgentThread[] = [];
|
|
||||||
let offset = initialOffset;
|
|
||||||
|
|
||||||
while (true) {
|
|
||||||
if (typeof maxResults === "number" && threads.length >= maxResults) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const currentLimit =
|
|
||||||
typeof maxResults === "number"
|
|
||||||
? Math.min(pageSize, maxResults - threads.length)
|
|
||||||
: pageSize;
|
|
||||||
|
|
||||||
if (typeof maxResults === "number" && currentLimit <= 0) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
const response = (await apiClient.threads.search<AgentThreadState>({
|
|
||||||
...params,
|
|
||||||
limit: currentLimit,
|
|
||||||
offset,
|
|
||||||
})) as AgentThread[];
|
|
||||||
|
|
||||||
threads.push(...response);
|
|
||||||
|
|
||||||
if (response.length < currentLimit) {
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
|
|
||||||
offset += response.length;
|
|
||||||
}
|
|
||||||
|
|
||||||
return threads;
|
|
||||||
},
|
},
|
||||||
|
getNextPageParam: (lastPage, allPages) =>
|
||||||
|
getInfiniteThreadsNextPageParam(lastPage, allPages),
|
||||||
refetchOnWindowFocus: false,
|
refetchOnWindowFocus: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useThreadRuns(threadId?: string) {
|
export function useThreadRuns(
|
||||||
|
threadId?: string,
|
||||||
|
{ enabled = true }: { enabled?: boolean } = {},
|
||||||
|
) {
|
||||||
const apiClient = getAPIClient();
|
const apiClient = getAPIClient();
|
||||||
return useQuery<Run[]>({
|
return useQuery<Run[]>({
|
||||||
queryKey: ["thread", threadId],
|
queryKey: ["thread", threadId],
|
||||||
@@ -1088,6 +1233,7 @@ export function useThreadRuns(threadId?: string) {
|
|||||||
const response = await apiClient.runs.list(threadId);
|
const response = await apiClient.runs.list(threadId);
|
||||||
return response;
|
return response;
|
||||||
},
|
},
|
||||||
|
enabled: enabled && Boolean(threadId),
|
||||||
refetchOnWindowFocus: false,
|
refetchOnWindowFocus: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -1156,9 +1302,21 @@ export function useDeleteThread() {
|
|||||||
return oldData.filter((t) => t.thread_id !== threadId);
|
return oldData.filter((t) => t.thread_id !== threadId);
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||||
|
filterInfiniteThreadsCache(oldData, (t) => t.thread_id !== threadId),
|
||||||
|
);
|
||||||
},
|
},
|
||||||
|
|
||||||
onSettled() {
|
onSettled() {
|
||||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||||
|
void queryClient.invalidateQueries({
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
});
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
@@ -1199,6 +1357,24 @@ export function useRenameThread() {
|
|||||||
});
|
});
|
||||||
},
|
},
|
||||||
);
|
);
|
||||||
|
queryClient.setQueriesData(
|
||||||
|
{
|
||||||
|
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||||
|
exact: false,
|
||||||
|
},
|
||||||
|
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||||
|
mapInfiniteThreadsCache(oldData, (t) =>
|
||||||
|
t.thread_id === threadId
|
||||||
|
? {
|
||||||
|
...t,
|
||||||
|
values: {
|
||||||
|
...t.values,
|
||||||
|
title,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
: t,
|
||||||
|
),
|
||||||
|
);
|
||||||
},
|
},
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user