mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 01:45:58 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2d5f0787de | |||
| 5819bd8a59 | |||
| b3c2cc42cf | |||
| 167ef4512f | |||
| ba9cc5e972 |
@@ -10,7 +10,7 @@ permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
lint-backend:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
@@ -343,8 +343,6 @@ See the [MCP Server Guide](backend/docs/MCP_SERVER.md) for detailed instructions
|
||||
|
||||
DeerFlow supports receiving tasks from messaging apps. Channels auto-start when configured — no public IP required for any of them.
|
||||
|
||||
DeerFlow can also expose user-owned IM channel connections in the workspace UI. When `channel_connections` is enabled, logged-in users can bind Telegram, Slack, or Discord from the sidebar / Settings > Channels. It reuses the existing outbound `channels.*` transports, so no public IP or provider callback URL is required. Incoming IM messages then run under the connected DeerFlow user account. See [IM Channel Connections](backend/docs/IM_CHANNEL_CONNECTIONS.md) for setup and security notes.
|
||||
|
||||
| Channel | Transport | Difficulty |
|
||||
|---------|-----------|------------|
|
||||
| Telegram | Bot API (long-polling) | Easy |
|
||||
|
||||
+17
-22
@@ -369,7 +369,8 @@ Proxied through nginx: `/api/langgraph/*` → Gateway LangGraph-compatible runti
|
||||
|
||||
### IM Channels System (`app/channels/`)
|
||||
|
||||
Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||
Bridges external messaging platforms (Feishu, Slack, Telegram, DingTalk) to the DeerFlow agent via Gateway's LangGraph-compatible API.
|
||||
|
||||
|
||||
**Architecture**: Channels communicate with Gateway through the `langgraph-sdk` HTTP client (same as the frontend), ensuring threads are created and managed server-side. The internal SDK client injects process-local internal auth plus a matching CSRF cookie/header pair so Gateway accepts state-changing thread/run requests from channel workers without relying on browser session cookies.
|
||||
|
||||
@@ -379,21 +380,18 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk
|
||||
- `manager.py` - Core dispatcher: creates threads via `client.threads.create()`, routes commands, keeps Slack/Telegram on `client.runs.wait()`, and uses `client.runs.stream(["messages-tuple", "values"])` for Feishu incremental outbound updates
|
||||
- `base.py` - Abstract `Channel` base class (start/stop/send lifecycle)
|
||||
- `service.py` - Manages lifecycle of all configured channels from `config.yaml`
|
||||
- `slack.py` / `feishu.py` / `telegram.py` / `discord.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
||||
- `app/gateway/routers/channel_connections.py` - Browser-facing user connection and disconnect APIs
|
||||
- `deerflow.persistence.channel_connections` - SQL-backed user-owned connection, optional credential, connect state, and conversation store
|
||||
- `slack.py` / `feishu.py` / `telegram.py` / `dingtalk.py` - Platform-specific implementations (`feishu.py` tracks the running card `message_id` in memory and patches the same card in place; `dingtalk.py` optionally uses AI Card streaming for in-place updates when `card_template_id` is configured)
|
||||
|
||||
**Message Flow**:
|
||||
1. External platform -> Channel impl -> `MessageBus.publish_inbound()`
|
||||
2. `ChannelManager._dispatch_loop()` consumes from queue
|
||||
3. For user-owned channel connections, incoming messages carry `connection_id`, `owner_user_id`, and `workspace_id`; `owner_user_id` becomes the DeerFlow run `user_id`, while the raw platform user id remains `channel_user_id`
|
||||
4. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
||||
5. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||
6. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
||||
7. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||
8. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
||||
9. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||
10. Outbound → channel callbacks → platform reply
|
||||
3. For chat: look up/create thread through Gateway's LangGraph-compatible API
|
||||
4. Feishu chat: `runs.stream()` → accumulate AI text → publish multiple outbound updates (`is_final=False`) → publish final outbound (`is_final=True`)
|
||||
5. Slack/Telegram chat: `runs.wait()` → extract final response → publish outbound
|
||||
6. Feishu channel sends one running reply card up front, then patches the same card for each outbound update (card JSON sets `config.update_multi=true` for Feishu's patch API requirement)
|
||||
7. DingTalk AI Card mode (when `card_template_id` configured): `runs.stream()` → create card with initial text → stream updates via `PUT /v1.0/card/streaming` → finalize on `is_final=True`. Falls back to `sampleMarkdown` if card creation or streaming fails
|
||||
8. For commands (`/new`, `/status`, `/models`, `/memory`, `/help`): handle locally or query Gateway API
|
||||
9. Outbound → channel callbacks → platform reply
|
||||
|
||||
**Configuration** (`config.yaml` -> `channels`):
|
||||
- `langgraph_url` - LangGraph-compatible Gateway API base URL (default: `http://localhost:8001/api`)
|
||||
@@ -401,16 +399,6 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk
|
||||
- In Docker Compose, IM channels run inside the `gateway` container, so `localhost` points back to that container. Use `http://gateway:8001/api` for `langgraph_url` and `http://gateway:8001` for `gateway_url`, or set `DEER_FLOW_CHANNELS_LANGGRAPH_URL` / `DEER_FLOW_CHANNELS_GATEWAY_URL`.
|
||||
- Per-channel configs: `feishu` (app_id, app_secret), `slack` (bot_token, app_token), `telegram` (bot_token), `dingtalk` (client_id, client_secret, optional `card_template_id` for AI Card streaming)
|
||||
|
||||
**User-owned channel connections** (`config.yaml` -> `channel_connections`):
|
||||
- Disabled by default. It is a user-binding layer on top of the existing `channels.*` runtime config, not a replacement for provider bot credentials.
|
||||
- No public IP, OAuth callback URL, or provider webhook route is required by the current implementation.
|
||||
- Telegram uses a deep-link `/start <code>` flow over the existing long-polling worker. Slack uses `/connect <code>` over the existing Socket Mode worker. Discord uses `/connect <code>` over the existing Gateway worker.
|
||||
- Frontend APIs: `GET /api/channels/providers`, `GET /api/channels/connections`, `POST /api/channels/{provider}/connect`, and `DELETE /api/channels/connections/{connection_id}`.
|
||||
- Browser APIs remain protected by normal Gateway auth/CSRF. Provider messages arrive through the already-configured channel workers.
|
||||
- Slack replies use the configured operator bot token from `channels.slack` unless a future provider-token flow stores per-connection credentials.
|
||||
- Telegram, Slack, and Discord workers resolve incoming platform identities to connection records before reaching `ChannelManager`.
|
||||
- See `backend/docs/IM_CHANNEL_CONNECTIONS.md` for provider setup and operational notes.
|
||||
|
||||
|
||||
### Memory System (`packages/harness/deerflow/agents/memory/`)
|
||||
|
||||
@@ -441,6 +429,12 @@ Bridges external messaging platforms (Feishu, Slack, Telegram, Discord, DingTalk
|
||||
4. Applies updates atomically (temp file + rename) with cache invalidation, skipping duplicate fact content before append
|
||||
5. Next interaction injects top 15 facts + context into `<memory>` tags in system prompt
|
||||
|
||||
**Token counting** (`packages/harness/deerflow/agents/memory/prompt.py`):
|
||||
- `_count_tokens` budgets the injection. In default `tiktoken` mode, the encoding is loaded lazily and cached.
|
||||
- Failed tiktoken loads are cached with a timestamp. During the fixed cooldown (`_TIKTOKEN_RETRY_COOLDOWN_S`, 600s), callers fall back to char estimation immediately instead of re-triggering the blocking BPE download; after the cooldown, transient outages can self-heal without a restart.
|
||||
- In-flight loads are cached as a LOADING sentinel so concurrent callers fall back instead of spawning more blocking threads.
|
||||
- Set `memory.token_counting: char` to skip tiktoken entirely and use the network-free CJK-aware char estimate.
|
||||
|
||||
Focused regression coverage for the updater lives in `backend/tests/test_memory_updater.py`.
|
||||
|
||||
**Configuration** (`config.yaml` → `memory`):
|
||||
@@ -450,6 +444,7 @@ Focused regression coverage for the updater lives in `backend/tests/test_memory_
|
||||
- `model_name` - LLM for updates (null = default model)
|
||||
- `max_facts` / `fact_confidence_threshold` - Fact storage limits (100 / 0.7)
|
||||
- `max_injection_tokens` - Token limit for prompt injection (2000)
|
||||
- `token_counting` - Token counting strategy for the injection budget: `tiktoken` (default, accurate but may download BPE data from a public endpoint on first use — can block for a long time in network-restricted environments, see issues #3402/#3429) or `char` (network-free CJK-aware char estimate, never touches tiktoken)
|
||||
|
||||
### Reflection System (`packages/harness/deerflow/reflection/`)
|
||||
|
||||
|
||||
@@ -11,23 +11,13 @@ from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DISCORD_MAX_MESSAGE_LEN = 2000
|
||||
|
||||
|
||||
def _extract_connect_code(text: str) -> str | None:
|
||||
parts = text.strip().split()
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
command = parts[0].lower()
|
||||
if command in {"/connect", "connect"}:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
class DiscordChannel(Channel):
|
||||
"""Discord bot channel.
|
||||
|
||||
@@ -80,7 +70,6 @@ class DiscordChannel(Channel):
|
||||
self._discord_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._main_loop: asyncio.AbstractEventLoop | None = None
|
||||
self._discord_module = None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -298,10 +287,6 @@ class DiscordChannel(Channel):
|
||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||
|
||||
connect_code = _extract_connect_code(text)
|
||||
if connect_code and await self._bind_connection_from_connect_code(message, connect_code):
|
||||
return
|
||||
|
||||
# --- Determine thread/channel routing and typing target ---
|
||||
thread_id = None
|
||||
chat_id = None
|
||||
@@ -330,7 +315,6 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
self._publish(inbound)
|
||||
# Start typing indicator in the thread
|
||||
if typing_target:
|
||||
@@ -438,7 +422,6 @@ class DiscordChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
inbound = await self._attach_connection_identity(inbound, guild_id=str(guild.id) if guild else None)
|
||||
|
||||
# Start typing indicator in the correct target (thread or channel)
|
||||
if typing_target:
|
||||
@@ -453,76 +436,6 @@ class DiscordChannel(Channel):
|
||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage, guild_id: str | None = None) -> InboundMessage:
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
connection = None
|
||||
if guild_id:
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="discord",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=guild_id,
|
||||
)
|
||||
if connection is None:
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="discord",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=None,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
async def _bind_connection_from_connect_code(self, message, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="discord", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(message, "Discord connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
guild = getattr(message, "guild", None)
|
||||
channel = getattr(message, "channel", None)
|
||||
author = getattr(message, "author", None)
|
||||
user_id = str(getattr(author, "id", "") or "")
|
||||
if not user_id:
|
||||
await self._send_connection_reply(message, "Discord connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
guild_id = str(getattr(guild, "id", "") or "") or None
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="discord",
|
||||
external_account_id=user_id,
|
||||
external_account_name=getattr(author, "display_name", None) or getattr(author, "name", None),
|
||||
workspace_id=guild_id,
|
||||
workspace_name=getattr(guild, "name", None) if guild is not None else None,
|
||||
metadata={
|
||||
"guild_id": guild_id,
|
||||
"channel_id": str(getattr(channel, "id", "") or ""),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(message, "Discord connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
async def _send_connection_reply(message, text: str) -> None:
|
||||
channel = getattr(message, "channel", None)
|
||||
send = getattr(channel, "send", None)
|
||||
if send is None:
|
||||
return
|
||||
try:
|
||||
await send(text)
|
||||
except Exception:
|
||||
logger.exception("[Discord] failed to send connection reply")
|
||||
|
||||
def _run_client(self) -> None:
|
||||
self._discord_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(self._discord_loop)
|
||||
|
||||
@@ -670,7 +670,6 @@ class ChannelManager:
|
||||
assistant_id: str = DEFAULT_ASSISTANT_ID,
|
||||
default_session: dict[str, Any] | None = None,
|
||||
channel_sessions: dict[str, Any] | None = None,
|
||||
connection_repo: Any | None = None,
|
||||
) -> None:
|
||||
self.bus = bus
|
||||
self.store = store
|
||||
@@ -680,7 +679,6 @@ class ChannelManager:
|
||||
self._assistant_id = assistant_id
|
||||
self._default_session = _as_dict(default_session)
|
||||
self._channel_sessions = dict(channel_sessions or {})
|
||||
self._connection_repo = connection_repo
|
||||
self._client = None # lazy init — langgraph_sdk async client
|
||||
self._skill_storage: SkillStorage | None = None
|
||||
self._csrf_token = generate_csrf_token()
|
||||
@@ -730,16 +728,12 @@ class ChannelManager:
|
||||
configurable["checkpoint_ns"] = ""
|
||||
configurable["thread_id"] = thread_id
|
||||
|
||||
# ``user_id`` drives DeerFlow-owned memory, files, and thread buckets.
|
||||
# For browser-connected IM channels, prefer the DeerFlow account that
|
||||
# owns the connection. Preserve the raw platform user under
|
||||
# ``channel_user_id`` for platform-facing lookups and audits.
|
||||
# ``user_id`` drives user-scoped filesystem buckets that only accept
|
||||
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
|
||||
# under ``channel_user_id`` for platform-facing lookups.
|
||||
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||
if msg.owner_user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.owner_user_id)
|
||||
elif msg.user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||
if msg.user_id:
|
||||
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||
run_context_identity["channel_user_id"] = msg.user_id
|
||||
|
||||
run_context = _merge_dicts(
|
||||
@@ -883,27 +877,10 @@ class ChannelManager:
|
||||
|
||||
# -- chat handling -----------------------------------------------------
|
||||
|
||||
async def _lookup_thread_id(self, msg: InboundMessage) -> str | None:
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
return await self._connection_repo.get_thread_id(
|
||||
msg.connection_id,
|
||||
msg.chat_id,
|
||||
msg.topic_id,
|
||||
)
|
||||
return self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
|
||||
async def _store_thread_id(self, msg: InboundMessage, thread_id: str) -> None:
|
||||
if msg.connection_id and msg.owner_user_id and self._connection_repo is not None:
|
||||
await self._connection_repo.set_thread_id(
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
provider=msg.channel_name,
|
||||
external_conversation_id=msg.chat_id,
|
||||
external_topic_id=msg.topic_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
return
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
thread = await client.threads.create()
|
||||
thread_id = thread["thread_id"]
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
@@ -911,12 +888,6 @@ class ChannelManager:
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
|
||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||
"""Create a new thread through Gateway and store the mapping."""
|
||||
thread = await client.threads.create()
|
||||
thread_id = thread["thread_id"]
|
||||
await self._store_thread_id(msg, thread_id)
|
||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||
return thread_id
|
||||
|
||||
@@ -926,7 +897,7 @@ class ChannelManager:
|
||||
# Look up existing DeerFlow thread.
|
||||
# topic_id may be None (e.g. Telegram private chats) — the store
|
||||
# handles this by using the "channel:chat_id" key without a topic suffix.
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
if thread_id:
|
||||
logger.info("[Manager] reusing thread: thread_id=%s for topic_id=%s", thread_id, msg.topic_id)
|
||||
|
||||
@@ -1013,8 +984,6 @@ class ChannelManager:
|
||||
artifacts=artifacts,
|
||||
attachments=attachments,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
logger.info("[Manager] publishing outbound message to bus: channel=%s, chat_id=%s", msg.channel_name, msg.chat_id)
|
||||
@@ -1078,8 +1047,6 @@ class ChannelManager:
|
||||
text=latest_text,
|
||||
is_final=False,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata),
|
||||
)
|
||||
)
|
||||
@@ -1126,8 +1093,6 @@ class ChannelManager:
|
||||
attachments=attachments,
|
||||
is_final=True,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_response_metadata(msg.metadata, pending_clarification=pending_clarification),
|
||||
)
|
||||
)
|
||||
@@ -1161,10 +1126,16 @@ class ChannelManager:
|
||||
client = self._get_client()
|
||||
thread = await client.threads.create()
|
||||
new_thread_id = thread["thread_id"]
|
||||
await self._store_thread_id(msg, new_thread_id)
|
||||
self.store.set_thread_id(
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
new_thread_id,
|
||||
topic_id=msg.topic_id,
|
||||
user_id=msg.user_id,
|
||||
)
|
||||
reply = "New conversation started."
|
||||
elif reply is None and command == "status":
|
||||
thread_id = await self._lookup_thread_id(msg)
|
||||
thread_id = self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id)
|
||||
reply = f"Active thread: {thread_id}" if thread_id else "No active conversation."
|
||||
elif reply is None and command == "models":
|
||||
reply = await self._fetch_gateway("/api/models", "models")
|
||||
@@ -1203,11 +1174,9 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
||||
text=reply,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
@@ -1243,11 +1212,9 @@ class ChannelManager:
|
||||
outbound = OutboundMessage(
|
||||
channel_name=msg.channel_name,
|
||||
chat_id=msg.chat_id,
|
||||
thread_id=await self._lookup_thread_id(msg) or "",
|
||||
thread_id=self.store.get_thread_id(msg.channel_name, msg.chat_id, topic_id=msg.topic_id) or "",
|
||||
text=error_text,
|
||||
thread_ts=msg.thread_ts,
|
||||
connection_id=msg.connection_id,
|
||||
owner_user_id=msg.owner_user_id,
|
||||
metadata=_slim_metadata(msg.metadata),
|
||||
)
|
||||
await self.bus.publish_outbound(outbound)
|
||||
|
||||
@@ -44,12 +44,6 @@ class InboundMessage:
|
||||
Messages sharing the same ``topic_id`` within a ``chat_id`` will
|
||||
reuse the same DeerFlow thread. When ``None``, each message
|
||||
creates a new thread (one-shot Q&A).
|
||||
connection_id: Optional DeerFlow channel connection id. When present,
|
||||
conversation mapping is scoped by the connection instead of the
|
||||
legacy global ``channel_name:chat_id[:topic_id]`` key.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
Platform user ids stay in ``user_id``.
|
||||
workspace_id: Optional external workspace/guild/team id.
|
||||
files: Optional list of file attachments (platform-specific dicts).
|
||||
metadata: Arbitrary extra data from the channel.
|
||||
created_at: Unix timestamp when the message was created.
|
||||
@@ -62,9 +56,6 @@ class InboundMessage:
|
||||
msg_type: InboundMessageType = InboundMessageType.CHAT
|
||||
thread_ts: str | None = None
|
||||
topic_id: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
workspace_id: str | None = None
|
||||
files: list[dict[str, Any]] = field(default_factory=list)
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
@@ -104,9 +95,6 @@ class OutboundMessage:
|
||||
is_final: Whether this is the final message in the response stream.
|
||||
thread_ts: Optional platform thread identifier for threaded replies.
|
||||
metadata: Arbitrary extra data.
|
||||
connection_id: Optional DeerFlow channel connection id used for
|
||||
connection-specific outbound credentials.
|
||||
owner_user_id: DeerFlow user id that owns the channel connection.
|
||||
created_at: Unix timestamp.
|
||||
"""
|
||||
|
||||
@@ -118,8 +106,6 @@ class OutboundMessage:
|
||||
attachments: list[ResolvedAttachment] = field(default_factory=list)
|
||||
is_final: bool = True
|
||||
thread_ts: str | None = None
|
||||
connection_id: str | None = None
|
||||
owner_user_id: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
|
||||
@@ -52,31 +52,6 @@ def _resolve_service_url(config: dict[str, Any], config_key: str, env_key: str,
|
||||
return default
|
||||
|
||||
|
||||
def _merge_channel_connection_runtime_config(channels_config: dict[str, Any], app_config: AppConfig) -> None:
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return
|
||||
|
||||
|
||||
def _make_connection_repo(app_config: AppConfig):
|
||||
connection_config = getattr(app_config, "channel_connections", None)
|
||||
if connection_config is None or not getattr(connection_config, "enabled", False):
|
||||
return None
|
||||
|
||||
try:
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
except Exception:
|
||||
logger.exception("Failed to import channel connection repository")
|
||||
return None
|
||||
|
||||
session_factory = get_session_factory()
|
||||
if session_factory is None:
|
||||
logger.warning("Channel connections are enabled but database persistence is not available")
|
||||
return None
|
||||
return ChannelConnectionRepository(session_factory)
|
||||
|
||||
|
||||
class ChannelService:
|
||||
"""Manages the lifecycle of all configured IM channels.
|
||||
|
||||
@@ -84,10 +59,9 @@ class ChannelService:
|
||||
instantiates enabled channels, and starts the ChannelManager dispatcher.
|
||||
"""
|
||||
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None, *, connection_repo: Any | None = None) -> None:
|
||||
def __init__(self, channels_config: dict[str, Any] | None = None) -> None:
|
||||
self.bus = MessageBus()
|
||||
self.store = ChannelStore()
|
||||
self._connection_repo = connection_repo
|
||||
config = dict(channels_config or {})
|
||||
langgraph_url = _resolve_service_url(config, "langgraph_url", _CHANNELS_LANGGRAPH_URL_ENV, DEFAULT_LANGGRAPH_URL)
|
||||
gateway_url = _resolve_service_url(config, "gateway_url", _CHANNELS_GATEWAY_URL_ENV, DEFAULT_GATEWAY_URL)
|
||||
@@ -100,7 +74,6 @@ class ChannelService:
|
||||
gateway_url=gateway_url,
|
||||
default_session=default_session if isinstance(default_session, dict) else None,
|
||||
channel_sessions=channel_sessions,
|
||||
connection_repo=connection_repo,
|
||||
)
|
||||
self._channels: dict[str, Any] = {} # name -> Channel instance
|
||||
self._config = config
|
||||
@@ -117,9 +90,8 @@ class ChannelService:
|
||||
# extra fields are allowed by AppConfig (extra="allow")
|
||||
extra = app_config.model_extra or {}
|
||||
if "channels" in extra:
|
||||
channels_config = dict(extra["channels"] or {})
|
||||
_merge_channel_connection_runtime_config(channels_config, app_config)
|
||||
return cls(channels_config=channels_config, connection_repo=_make_connection_repo(app_config))
|
||||
channels_config = extra["channels"]
|
||||
return cls(channels_config=channels_config)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Start the manager and all enabled channels."""
|
||||
@@ -197,8 +169,6 @@ class ChannelService:
|
||||
try:
|
||||
config = dict(config)
|
||||
config["channel_store"] = self.store
|
||||
if self._connection_repo is not None:
|
||||
config["connection_repo"] = self._connection_repo
|
||||
channel = channel_cls(bus=self.bus, config=config)
|
||||
self._channels[name] = channel
|
||||
await channel.start()
|
||||
|
||||
+25
-160
@@ -47,16 +47,6 @@ def _strip_leading_slack_bot_mention(text: str, bot_user_id: str | None) -> str:
|
||||
return text[end + 1 :].lstrip()
|
||||
|
||||
|
||||
def _extract_connect_code(text: str) -> str | None:
|
||||
parts = text.strip().split()
|
||||
if len(parts) < 2:
|
||||
return None
|
||||
command = parts[0].lower()
|
||||
if command in {"/connect", "connect"}:
|
||||
return parts[1]
|
||||
return None
|
||||
|
||||
|
||||
class SlackChannel(Channel):
|
||||
"""Slack IM channel using Socket Mode (WebSocket, no public IP).
|
||||
|
||||
@@ -74,8 +64,6 @@ class SlackChannel(Channel):
|
||||
self._web_client = None
|
||||
self._loop: asyncio.AbstractEventLoop | None = None
|
||||
self._allowed_users = _normalize_allowed_users(config.get("allowed_users", []))
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._web_client_factory = config.get("web_client_factory")
|
||||
configured_bot_user_id = config.get("bot_user_id")
|
||||
self._bot_user_id = str(configured_bot_user_id).lstrip("@") if configured_bot_user_id else None
|
||||
|
||||
@@ -92,28 +80,26 @@ class SlackChannel(Channel):
|
||||
return
|
||||
|
||||
self._SocketModeResponse = SocketModeResponse
|
||||
if self._web_client_factory is None:
|
||||
self._web_client_factory = WebClient
|
||||
|
||||
bot_token = self.config.get("bot_token", "")
|
||||
app_token = self.config.get("app_token", "")
|
||||
|
||||
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||
if not bot_token:
|
||||
logger.error("Slack HTTP Events mode requires bot_token")
|
||||
return
|
||||
await self._initialize_operator_web_client(str(bot_token))
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._running = True
|
||||
self.bus.subscribe_outbound(self._on_outbound)
|
||||
logger.info("Slack channel started in HTTP Events mode")
|
||||
return
|
||||
|
||||
if not bot_token or not app_token:
|
||||
logger.error("Slack channel requires bot_token and app_token")
|
||||
return
|
||||
|
||||
await self._initialize_operator_web_client(str(bot_token))
|
||||
self._web_client = WebClient(token=bot_token)
|
||||
if self._bot_user_id is None:
|
||||
try:
|
||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||
if user_id is None:
|
||||
auth_get = getattr(auth_info, "get", None)
|
||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||
if isinstance(user_id, str) and user_id:
|
||||
self._bot_user_id = user_id
|
||||
except Exception:
|
||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||
self._socket_client = SocketModeClient(
|
||||
app_token=app_token,
|
||||
web_client=self._web_client,
|
||||
@@ -138,8 +124,7 @@ class SlackChannel(Channel):
|
||||
logger.info("Slack channel stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
if not self._web_client:
|
||||
return
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
@@ -152,12 +137,11 @@ class SlackChannel(Channel):
|
||||
last_exc: Exception | None = None
|
||||
for attempt in range(_max_retries):
|
||||
try:
|
||||
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
||||
await asyncio.to_thread(self._web_client.chat_postMessage, **kwargs)
|
||||
# Add a completion reaction to the thread root
|
||||
if msg.thread_ts:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
self._add_reaction,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"white_check_mark",
|
||||
@@ -181,8 +165,7 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
self._add_reaction_with_client,
|
||||
web_client,
|
||||
self._add_reaction,
|
||||
msg.chat_id,
|
||||
msg.thread_ts,
|
||||
"x",
|
||||
@@ -194,8 +177,7 @@ class SlackChannel(Channel):
|
||||
raise last_exc
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
web_client = await self._get_web_client_for_message(msg)
|
||||
if not web_client:
|
||||
if not self._web_client:
|
||||
return False
|
||||
|
||||
try:
|
||||
@@ -208,7 +190,7 @@ class SlackChannel(Channel):
|
||||
if msg.thread_ts:
|
||||
kwargs["thread_ts"] = msg.thread_ts
|
||||
|
||||
await asyncio.to_thread(web_client.files_upload_v2, **kwargs)
|
||||
await asyncio.to_thread(self._web_client.files_upload_v2, **kwargs)
|
||||
logger.info("[Slack] file uploaded: %s to channel=%s", attachment.filename, msg.chat_id)
|
||||
return True
|
||||
except Exception:
|
||||
@@ -217,38 +199,12 @@ class SlackChannel(Channel):
|
||||
|
||||
# -- internal ----------------------------------------------------------
|
||||
|
||||
async def _initialize_operator_web_client(self, bot_token: str) -> None:
|
||||
self._web_client = self._web_client_factory(token=bot_token)
|
||||
if self._bot_user_id is not None:
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
return
|
||||
try:
|
||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||
if user_id is None:
|
||||
auth_get = getattr(auth_info, "get", None)
|
||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||
if isinstance(user_id, str) and user_id:
|
||||
self._bot_user_id = user_id
|
||||
except Exception:
|
||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||
|
||||
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
||||
if msg.connection_id and self._connection_repo is not None:
|
||||
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
||||
access_token = credentials.get("access_token") if credentials else None
|
||||
if not access_token:
|
||||
return self._web_client
|
||||
if self._web_client_factory is None:
|
||||
from slack_sdk import WebClient
|
||||
|
||||
self._web_client_factory = WebClient
|
||||
return self._web_client_factory(token=access_token)
|
||||
return self._web_client
|
||||
|
||||
@staticmethod
|
||||
def _add_reaction_with_client(web_client, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
try:
|
||||
web_client.reactions_add(
|
||||
self._web_client.reactions_add(
|
||||
channel=channel_id,
|
||||
timestamp=timestamp,
|
||||
name=emoji,
|
||||
@@ -257,12 +213,6 @@ class SlackChannel(Channel):
|
||||
if "already_reacted" not in str(exc):
|
||||
logger.warning("[Slack] failed to add reaction %s: %s", emoji, exc)
|
||||
|
||||
def _add_reaction(self, channel_id: str, timestamp: str, emoji: str) -> None:
|
||||
"""Add an emoji reaction to a message (best-effort, non-blocking)."""
|
||||
if not self._web_client:
|
||||
return
|
||||
self._add_reaction_with_client(self._web_client, channel_id, timestamp, emoji)
|
||||
|
||||
def _send_running_reply(self, channel_id: str, thread_ts: str) -> None:
|
||||
"""Send a 'Working on it......' reply in the thread (called from SDK thread)."""
|
||||
if not self._web_client:
|
||||
@@ -299,15 +249,12 @@ class SlackChannel(Channel):
|
||||
|
||||
# Handle message events (DM or @mention)
|
||||
if etype in ("message", "app_mention"):
|
||||
self._handle_message_event(
|
||||
event,
|
||||
team_id=req.payload.get("team_id") or req.payload.get("team") or event.get("team"),
|
||||
)
|
||||
self._handle_message_event(event)
|
||||
|
||||
except Exception:
|
||||
logger.exception("Error processing Slack event")
|
||||
|
||||
def _handle_message_event(self, event: dict, *, team_id: str | None = None) -> None:
|
||||
def _handle_message_event(self, event: dict) -> None:
|
||||
# Ignore bot messages
|
||||
if event.get("bot_id") or event.get("subtype"):
|
||||
return
|
||||
@@ -325,19 +272,6 @@ class SlackChannel(Channel):
|
||||
if not text:
|
||||
return
|
||||
|
||||
connect_code = _extract_connect_code(text)
|
||||
if connect_code:
|
||||
if self._loop and self._loop.is_running():
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
event=event,
|
||||
team_id=str(team_id or event.get("team") or ""),
|
||||
code=connect_code,
|
||||
),
|
||||
self._loop,
|
||||
)
|
||||
return
|
||||
|
||||
channel_id = event.get("channel", "")
|
||||
thread_ts = event.get("thread_ts") or event.get("ts", "")
|
||||
|
||||
@@ -363,73 +297,4 @@ class SlackChannel(Channel):
|
||||
self._add_reaction(channel_id, event.get("ts", thread_ts), "eyes")
|
||||
# Send "running" reply first (fire-and-forget from SDK thread)
|
||||
self._send_running_reply(channel_id, thread_ts)
|
||||
if self._connection_repo is None:
|
||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||
else:
|
||||
asyncio.run_coroutine_threadsafe(self._publish_inbound_with_connection(inbound, team_id=team_id), self._loop)
|
||||
|
||||
async def _publish_inbound_with_connection(self, inbound, *, team_id: str | None = None) -> None:
|
||||
inbound = await self._attach_connection_identity(inbound, team_id=team_id)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound, *, team_id: str | None = None):
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
workspace_id = str(team_id or inbound.metadata.get("team_id") or "")
|
||||
if not workspace_id:
|
||||
return inbound
|
||||
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="slack",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=workspace_id,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, event: dict, team_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
channel_id = str(event.get("channel") or "")
|
||||
thread_ts = str(event.get("thread_ts") or event.get("ts") or "")
|
||||
state = await self._connection_repo.consume_oauth_state(provider="slack", state=code)
|
||||
if state is None:
|
||||
self._post_connection_reply(channel_id, "Slack connection code is invalid or expired.", thread_ts)
|
||||
return True
|
||||
|
||||
user_id = str(event.get("user") or "")
|
||||
if not user_id or not team_id:
|
||||
self._post_connection_reply(channel_id, "Slack connection could not be completed from this message.", thread_ts)
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="slack",
|
||||
external_account_id=user_id,
|
||||
workspace_id=team_id,
|
||||
metadata={
|
||||
"team_id": team_id,
|
||||
"channel_id": channel_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
self._post_connection_reply(channel_id, "Slack connected to DeerFlow.", thread_ts)
|
||||
return True
|
||||
|
||||
def _post_connection_reply(self, channel_id: str, text: str, thread_ts: str | None = None) -> None:
|
||||
if not self._web_client or not channel_id:
|
||||
return
|
||||
kwargs: dict[str, Any] = {"channel": channel_id, "text": text}
|
||||
if thread_ts:
|
||||
kwargs["thread_ts"] = thread_ts
|
||||
try:
|
||||
self._web_client.chat_postMessage(**kwargs)
|
||||
except Exception:
|
||||
logger.exception("[Slack] failed to send connection reply in channel=%s", channel_id)
|
||||
asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._loop)
|
||||
|
||||
@@ -35,7 +35,6 @@ class TelegramChannel(Channel):
|
||||
pass
|
||||
# chat_id -> last sent message_id for threaded replies
|
||||
self._last_bot_message: dict[str, int] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
async def start(self) -> None:
|
||||
if self._running:
|
||||
@@ -177,26 +176,6 @@ class TelegramChannel(Channel):
|
||||
logger.exception("[Telegram] failed to send file: %s", attachment.filename)
|
||||
return False
|
||||
|
||||
async def process_webhook_update(self, payload: dict[str, Any]) -> bool:
|
||||
if not self._application:
|
||||
return False
|
||||
try:
|
||||
from telegram import Update
|
||||
except ImportError:
|
||||
logger.error("python-telegram-bot is not installed. Install it with: uv add python-telegram-bot")
|
||||
return False
|
||||
|
||||
update = Update.de_json(payload, self._application.bot)
|
||||
if update is None:
|
||||
return False
|
||||
|
||||
if self._tg_loop and self._tg_loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(self._application.process_update(update), self._tg_loop)
|
||||
await asyncio.wrap_future(future)
|
||||
else:
|
||||
await self._application.process_update(update)
|
||||
return True
|
||||
|
||||
# -- helpers -----------------------------------------------------------
|
||||
|
||||
async def _send_running_reply(self, chat_id: str, reply_to_message_id: int) -> None:
|
||||
@@ -254,63 +233,6 @@ class TelegramChannel(Channel):
|
||||
return True
|
||||
return user_id in self._allowed_users
|
||||
|
||||
@staticmethod
|
||||
def _telegram_display_name(user) -> str:
|
||||
full_name = getattr(user, "full_name", None)
|
||||
if isinstance(full_name, str) and full_name:
|
||||
return full_name
|
||||
username = getattr(user, "username", None)
|
||||
if isinstance(username, str) and username:
|
||||
return username
|
||||
return str(getattr(user, "id", ""))
|
||||
|
||||
async def _bind_connection_from_start_token(self, update, state_token: str) -> bool:
|
||||
if self._connection_repo is None or not state_token:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="telegram", state=state_token)
|
||||
if state is None:
|
||||
await update.message.reply_text("Telegram connection link is invalid or expired.")
|
||||
return True
|
||||
|
||||
owner_user_id = state["owner_user_id"]
|
||||
user_id = str(update.effective_user.id)
|
||||
chat_id = str(update.effective_chat.id)
|
||||
connection = await self._connection_repo.upsert_connection(
|
||||
owner_user_id=owner_user_id,
|
||||
provider="telegram",
|
||||
external_account_id=user_id,
|
||||
external_account_name=self._telegram_display_name(update.effective_user),
|
||||
workspace_id=chat_id,
|
||||
workspace_name=None,
|
||||
metadata={
|
||||
"chat_id": chat_id,
|
||||
"chat_type": update.effective_chat.type,
|
||||
"telegram_username": getattr(update.effective_user, "username", None),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
logger.info("[Telegram] bound chat=%s user=%s to DeerFlow user=%s connection=%s", chat_id, user_id, owner_user_id, connection["id"])
|
||||
await update.message.reply_text("Telegram connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
if self._connection_repo is None:
|
||||
return inbound
|
||||
|
||||
connection = await self._connection_repo.find_connection_by_external_identity(
|
||||
provider="telegram",
|
||||
external_account_id=inbound.user_id,
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
if connection is None:
|
||||
return inbound
|
||||
|
||||
inbound.connection_id = connection["id"]
|
||||
inbound.owner_user_id = connection["owner_user_id"]
|
||||
inbound.workspace_id = connection.get("workspace_id")
|
||||
return inbound
|
||||
|
||||
def _get_bot_username(self, context) -> str | None:
|
||||
bot = getattr(context, "bot", None)
|
||||
username = getattr(bot, "username", None)
|
||||
@@ -342,11 +264,6 @@ class TelegramChannel(Channel):
|
||||
"""Handle /start command."""
|
||||
if not self._check_user(update.effective_user.id):
|
||||
return
|
||||
args = getattr(context, "args", []) if context is not None else []
|
||||
if args:
|
||||
handled = await self._bind_connection_from_start_token(update, str(args[0]))
|
||||
if handled:
|
||||
return
|
||||
await update.message.reply_text("Welcome to DeerFlow! Send me a message to start a conversation.\nType /help for available commands.")
|
||||
|
||||
async def _process_incoming_with_reply(self, chat_id: str, msg_id: int, inbound: InboundMessage) -> None:
|
||||
@@ -382,7 +299,6 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
@@ -425,7 +341,6 @@ class TelegramChannel(Channel):
|
||||
thread_ts=msg_id,
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(self._process_incoming_with_reply(chat_id, update.message.message_id, inbound), self._main_loop)
|
||||
|
||||
+20
-18
@@ -16,7 +16,6 @@ from app.gateway.routers import (
|
||||
artifacts,
|
||||
assistants_compat,
|
||||
auth,
|
||||
channel_connections,
|
||||
channels,
|
||||
feedback,
|
||||
mcp,
|
||||
@@ -185,21 +184,27 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# Pre-warm tiktoken encoding cache so the first memory-injection request
|
||||
# never blocks on the BPE data download (which hits an OpenAI/Azure URL
|
||||
# that may be unreachable in restricted networks — see issue #3402).
|
||||
try:
|
||||
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||
# When memory.token_counting is "char", token counting never touches
|
||||
# tiktoken, so skip the warm-up entirely (avoids even the 5s probe in
|
||||
# network-restricted deployments — see issue #3429).
|
||||
if startup_config.memory.token_counting == "char":
|
||||
logger.info("memory.token_counting='char'; skipping tiktoken warm-up (network-free token estimation)")
|
||||
else:
|
||||
try:
|
||||
from deerflow.agents.memory.prompt import warm_tiktoken_cache
|
||||
|
||||
warmed = await asyncio.wait_for(
|
||||
asyncio.to_thread(warm_tiktoken_cache),
|
||||
timeout=5,
|
||||
)
|
||||
if warmed:
|
||||
logger.info("tiktoken encoding cache warmed successfully")
|
||||
else:
|
||||
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback")
|
||||
except TimeoutError:
|
||||
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback")
|
||||
except Exception:
|
||||
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||
warmed = await asyncio.wait_for(
|
||||
asyncio.to_thread(warm_tiktoken_cache),
|
||||
timeout=5,
|
||||
)
|
||||
if warmed:
|
||||
logger.info("tiktoken encoding cache warmed successfully")
|
||||
else:
|
||||
logger.warning("tiktoken encoding cache warm-up failed; token counting will use character-based fallback until tiktoken loads successfully")
|
||||
except TimeoutError:
|
||||
logger.warning("tiktoken encoding cache warm-up timed out; token counting will use character-based fallback until tiktoken loads successfully")
|
||||
except Exception:
|
||||
logger.warning("tiktoken warm-up skipped", exc_info=True)
|
||||
|
||||
# Initialize LangGraph runtime components (StreamBridge, RunManager, checkpointer, store)
|
||||
async with langgraph_runtime(app, startup_config):
|
||||
@@ -379,9 +384,6 @@ This gateway provides runtime endpoints for agent runs plus custom endpoints for
|
||||
# Suggestions API is mounted at /api/threads/{thread_id}/suggestions
|
||||
app.include_router(suggestions.router)
|
||||
|
||||
# User-facing IM channel connection API is mounted at /api/channels
|
||||
app.include_router(channel_connections.router)
|
||||
|
||||
# Channels API is mounted at /api/channels
|
||||
app.include_router(channels.router)
|
||||
|
||||
|
||||
@@ -1,296 +0,0 @@
|
||||
"""Browser-facing APIs for user-owned IM channel bindings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, HTTPException, Request, Response
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
|
||||
|
||||
_STATE_TTL_SECONDS = 600
|
||||
|
||||
|
||||
class ChannelProviderResponse(BaseModel):
|
||||
provider: str
|
||||
display_name: str
|
||||
enabled: bool
|
||||
configured: bool
|
||||
connectable: bool
|
||||
unavailable_reason: str | None = None
|
||||
auth_mode: str
|
||||
connection_status: str
|
||||
|
||||
|
||||
class ChannelProvidersResponse(BaseModel):
|
||||
enabled: bool
|
||||
providers: list[ChannelProviderResponse]
|
||||
|
||||
|
||||
class ChannelConnectionResponse(BaseModel):
|
||||
id: str
|
||||
provider: str
|
||||
status: str
|
||||
external_account_id: str | None = None
|
||||
external_account_name: str | None = None
|
||||
workspace_id: str | None = None
|
||||
workspace_name: str | None = None
|
||||
scopes: list[str] = Field(default_factory=list)
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ChannelConnectionsResponse(BaseModel):
|
||||
connections: list[ChannelConnectionResponse]
|
||||
|
||||
|
||||
class ChannelConnectResponse(BaseModel):
|
||||
provider: str
|
||||
mode: str
|
||||
url: str | None = None
|
||||
code: str
|
||||
instruction: str
|
||||
expires_in: int
|
||||
|
||||
|
||||
_PROVIDER_META: dict[str, dict[str, str]] = {
|
||||
"telegram": {"display_name": "Telegram", "auth_mode": "deep_link"},
|
||||
"slack": {"display_name": "Slack", "auth_mode": "binding_code"},
|
||||
"discord": {"display_name": "Discord", "auth_mode": "binding_code"},
|
||||
}
|
||||
|
||||
_RUNTIME_REQUIREMENTS: dict[str, tuple[str, ...]] = {
|
||||
"telegram": ("bot_token",),
|
||||
"slack": ("bot_token", "app_token"),
|
||||
"discord": ("bot_token",),
|
||||
}
|
||||
|
||||
|
||||
def _get_user_id(request: Request) -> str:
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return str(user.id)
|
||||
|
||||
|
||||
def _get_app_config():
|
||||
from deerflow.config.app_config import get_app_config
|
||||
|
||||
return get_app_config()
|
||||
|
||||
|
||||
def _get_channel_connections_config(request: Request) -> ChannelConnectionsConfig:
|
||||
config = getattr(request.app.state, "channel_connections_config", None)
|
||||
if isinstance(config, ChannelConnectionsConfig):
|
||||
return config
|
||||
return _get_app_config().channel_connections
|
||||
|
||||
|
||||
def _get_channels_config(request: Request) -> dict[str, Any]:
|
||||
state_config = getattr(request.app.state, "channels_config", None)
|
||||
if isinstance(state_config, dict):
|
||||
return state_config
|
||||
|
||||
app_config = _get_app_config()
|
||||
extra = app_config.model_extra or {}
|
||||
channels_config = extra.get("channels")
|
||||
return dict(channels_config) if isinstance(channels_config, dict) else {}
|
||||
|
||||
|
||||
def _get_repository(request: Request, config: ChannelConnectionsConfig) -> ChannelConnectionRepository:
|
||||
repo = getattr(request.app.state, "channel_connection_repo", None)
|
||||
if isinstance(repo, ChannelConnectionRepository):
|
||||
return repo
|
||||
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
raise HTTPException(status_code=503, detail="Channel connection persistence is not available")
|
||||
|
||||
repo = ChannelConnectionRepository(sf)
|
||||
request.app.state.channel_connection_repo = repo
|
||||
return repo
|
||||
|
||||
|
||||
def _provider_config(config: ChannelConnectionsConfig, provider: str):
|
||||
provider_config = getattr(config, provider, None)
|
||||
if provider_config is None:
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
return provider_config
|
||||
|
||||
|
||||
def _runtime_channel_configured(provider: str, channels_config: dict[str, Any]) -> bool:
|
||||
runtime_config = channels_config.get(provider)
|
||||
if not isinstance(runtime_config, dict) or not runtime_config.get("enabled", False):
|
||||
return False
|
||||
return all(str(runtime_config.get(key) or "").strip() for key in _RUNTIME_REQUIREMENTS[provider])
|
||||
|
||||
|
||||
def _runtime_unavailable_reason(provider: str) -> str:
|
||||
keys = " and ".join(f"channels.{provider}.{key}" for key in _RUNTIME_REQUIREMENTS[provider])
|
||||
return f"Enable and configure channels.{provider} with {keys}."
|
||||
|
||||
|
||||
def _provider_unavailable_reason(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
) -> str | None:
|
||||
provider_config = _provider_config(config, provider)
|
||||
if not provider_config.enabled:
|
||||
return None
|
||||
if not provider_config.configured:
|
||||
if provider == "telegram":
|
||||
return "Configure channel_connections.telegram.bot_username for Telegram deep links."
|
||||
return f"Configure channel_connections.{provider}."
|
||||
if not _runtime_channel_configured(provider, channels_config):
|
||||
return _runtime_unavailable_reason(provider)
|
||||
return None
|
||||
|
||||
|
||||
def _provider_status(
|
||||
config: ChannelConnectionsConfig,
|
||||
channels_config: dict[str, Any],
|
||||
provider: str,
|
||||
) -> tuple[dict[str, bool], str | None]:
|
||||
declared = config.provider_status(provider)
|
||||
unavailable_reason = _provider_unavailable_reason(config, channels_config, provider)
|
||||
configured = declared["configured"] and _runtime_channel_configured(provider, channels_config)
|
||||
return {"enabled": declared["enabled"], "configured": configured}, unavailable_reason
|
||||
|
||||
|
||||
def _new_binding_code() -> str:
|
||||
return secrets.token_urlsafe(16)
|
||||
|
||||
|
||||
async def _create_state(
|
||||
repo: ChannelConnectionRepository,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
) -> str:
|
||||
state = _new_binding_code()
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
def _connect_instruction(provider: str, code: str) -> str:
|
||||
if provider == "telegram":
|
||||
return f"Send /start {code} to the DeerFlow Telegram bot."
|
||||
if provider == "slack":
|
||||
return f"Send /connect {code} to the DeerFlow Slack bot."
|
||||
if provider == "discord":
|
||||
return f"Send /connect {code} to the DeerFlow Discord bot."
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
|
||||
|
||||
def _connect_url(config: ChannelConnectionsConfig, provider: str, code: str) -> str | None:
|
||||
if provider == "telegram":
|
||||
provider_config = _provider_config(config, provider)
|
||||
return f"https://t.me/{provider_config.bot_username}?start={code}"
|
||||
if provider in {"slack", "discord"}:
|
||||
return None
|
||||
raise HTTPException(status_code=404, detail="Unknown channel provider")
|
||||
|
||||
|
||||
@router.get("/providers", response_model=ChannelProvidersResponse)
|
||||
async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
channels_config = _get_channels_config(request)
|
||||
repo = None
|
||||
if config.enabled:
|
||||
try:
|
||||
repo = _get_repository(request, config)
|
||||
except HTTPException as exc:
|
||||
if exc.status_code != 503:
|
||||
raise
|
||||
owner_user_id = _get_user_id(request)
|
||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||
by_provider: dict[str, dict[str, Any]] = {}
|
||||
for item in connections:
|
||||
by_provider.setdefault(item["provider"], item)
|
||||
|
||||
providers: list[ChannelProviderResponse] = []
|
||||
for provider, meta in _PROVIDER_META.items():
|
||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||
connection = by_provider.get(provider)
|
||||
providers.append(
|
||||
ChannelProviderResponse(
|
||||
provider=provider,
|
||||
display_name=meta["display_name"],
|
||||
enabled=status["enabled"],
|
||||
configured=status["configured"],
|
||||
connectable=status["enabled"] and status["configured"] and unavailable_reason is None,
|
||||
unavailable_reason=unavailable_reason,
|
||||
auth_mode=meta["auth_mode"],
|
||||
connection_status=connection["status"] if connection else "not_connected",
|
||||
)
|
||||
)
|
||||
return ChannelProvidersResponse(enabled=config.enabled, providers=providers)
|
||||
|
||||
|
||||
@router.get("/connections", response_model=ChannelConnectionsResponse)
|
||||
async def get_channel_connections(request: Request) -> ChannelConnectionsResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
return ChannelConnectionsResponse(connections=[])
|
||||
repo = _get_repository(request, config)
|
||||
rows = await repo.list_connections(_get_user_id(request))
|
||||
return ChannelConnectionsResponse(connections=[ChannelConnectionResponse(**row) for row in rows])
|
||||
|
||||
|
||||
@router.delete("/connections/{connection_id}", status_code=204)
|
||||
async def disconnect_channel_connection(connection_id: str, request: Request) -> Response:
|
||||
config = _get_channel_connections_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection_id,
|
||||
owner_user_id=_get_user_id(request),
|
||||
)
|
||||
if not disconnected:
|
||||
raise HTTPException(status_code=404, detail="Channel connection not found")
|
||||
return Response(status_code=204)
|
||||
|
||||
|
||||
@router.post("/{provider}/connect", response_model=ChannelConnectResponse)
|
||||
async def connect_channel_provider(provider: str, request: Request) -> ChannelConnectResponse:
|
||||
config = _get_channel_connections_config(request)
|
||||
channels_config = _get_channels_config(request)
|
||||
if not config.enabled:
|
||||
raise HTTPException(status_code=400, detail="Channel connections are disabled")
|
||||
|
||||
status, unavailable_reason = _provider_status(config, channels_config, provider)
|
||||
if not status["enabled"]:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not enabled")
|
||||
if unavailable_reason:
|
||||
raise HTTPException(status_code=400, detail=unavailable_reason)
|
||||
if not status["configured"]:
|
||||
raise HTTPException(status_code=400, detail="Channel provider is not configured")
|
||||
|
||||
repo = _get_repository(request, config)
|
||||
code = await _create_state(
|
||||
repo,
|
||||
owner_user_id=_get_user_id(request),
|
||||
provider=provider,
|
||||
)
|
||||
return ChannelConnectResponse(
|
||||
provider=provider,
|
||||
mode=_PROVIDER_META[provider]["auth_mode"],
|
||||
url=_connect_url(config, provider, code),
|
||||
code=code,
|
||||
instruction=_connect_instruction(provider, code),
|
||||
expires_in=_STATE_TTL_SECONDS,
|
||||
)
|
||||
@@ -98,6 +98,7 @@ class MemoryConfigResponse(BaseModel):
|
||||
fact_confidence_threshold: float = Field(..., description="Minimum confidence threshold for facts")
|
||||
injection_enabled: bool = Field(..., description="Whether memory injection is enabled")
|
||||
max_injection_tokens: int = Field(..., description="Maximum tokens for memory injection")
|
||||
token_counting: str = Field(..., description="Token counting strategy for memory injection ('tiktoken' or 'char')")
|
||||
|
||||
|
||||
class MemoryStatusResponse(BaseModel):
|
||||
@@ -310,7 +311,8 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
"max_facts": 100,
|
||||
"fact_confidence_threshold": 0.7,
|
||||
"injection_enabled": true,
|
||||
"max_injection_tokens": 2000
|
||||
"max_injection_tokens": 2000,
|
||||
"token_counting": "tiktoken"
|
||||
}
|
||||
```
|
||||
"""
|
||||
@@ -323,6 +325,7 @@ async def get_memory_config_endpoint() -> MemoryConfigResponse:
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
token_counting=config.token_counting,
|
||||
)
|
||||
|
||||
|
||||
@@ -351,6 +354,7 @@ async def get_memory_status() -> MemoryStatusResponse:
|
||||
fact_confidence_threshold=config.fact_confidence_threshold,
|
||||
injection_enabled=config.injection_enabled,
|
||||
max_injection_tokens=config.max_injection_tokens,
|
||||
token_counting=config.token_counting,
|
||||
),
|
||||
data=MemoryResponse(**memory_data),
|
||||
)
|
||||
|
||||
@@ -315,6 +315,21 @@ async def start_run(
|
||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||
)
|
||||
|
||||
# Stateless run endpoints carry thread_id in the request *body*, so the
|
||||
# @require_permission(owner_check=True) decorator -- which resolves ownership
|
||||
# from the path param -- cannot protect them. Enforce thread ownership here,
|
||||
# before any run is created, so one user cannot start runs on (or read /wait
|
||||
# checkpoint state from) another user's thread. Missing rows (auto-created
|
||||
# temp threads) and NULL-owner rows (shared / pre-auth data) stay accessible
|
||||
# via check_access; only a thread already owned by another user is rejected
|
||||
# with 404, matching thread_runs.py's anti-enumeration behaviour. Internal
|
||||
# channel runs act on behalf of IM users they do not own (see
|
||||
# inject_authenticated_user_context), so the internal system role is exempt.
|
||||
user = getattr(request.state, "user", None)
|
||||
if user is not None and getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||
if not await run_ctx.thread_store.check_access(thread_id, str(user.id)):
|
||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
# IM Channel Connections
|
||||
|
||||
DeerFlow supports user-owned IM channel bindings for Telegram, Slack, and Discord. The feature reuses the existing `channels.*` runtime configuration, so it works in local and private deployments with the same outbound transports already supported by DeerFlow.
|
||||
|
||||
No public IP, OAuth callback URL, or provider webhook is required in this implementation.
|
||||
|
||||
## Configuration
|
||||
|
||||
Configure the actual IM bots under the existing `channels` block:
|
||||
|
||||
```yaml
|
||||
channels:
|
||||
telegram:
|
||||
enabled: true
|
||||
bot_token: $TELEGRAM_BOT_TOKEN
|
||||
|
||||
slack:
|
||||
enabled: true
|
||||
bot_token: $SLACK_BOT_TOKEN
|
||||
app_token: $SLACK_APP_TOKEN
|
||||
|
||||
discord:
|
||||
enabled: true
|
||||
bot_token: $DISCORD_BOT_TOKEN
|
||||
```
|
||||
|
||||
Then enable user bindings in `channel_connections`:
|
||||
|
||||
```yaml
|
||||
channel_connections:
|
||||
enabled: true
|
||||
|
||||
telegram:
|
||||
enabled: true
|
||||
bot_username: $TELEGRAM_BOT_USERNAME
|
||||
|
||||
slack:
|
||||
enabled: true
|
||||
|
||||
discord:
|
||||
enabled: true
|
||||
```
|
||||
|
||||
`channel_connections` does not duplicate provider secrets. It only controls the browser-facing connect UI and stores per-user binding records. Telegram needs `bot_username` only so the frontend can open a deep link.
|
||||
|
||||
## Connect Flow
|
||||
|
||||
Telegram:
|
||||
|
||||
- The frontend creates a short one-time code.
|
||||
- The Connect button opens `https://t.me/<bot_username>?start=<code>`.
|
||||
- The existing Telegram long-polling worker receives `/start <code>` and binds that Telegram chat/user to the current DeerFlow user.
|
||||
|
||||
Slack:
|
||||
|
||||
- The frontend creates a short one-time code.
|
||||
- The UI shows `Send /connect <code> to the DeerFlow Slack bot.`
|
||||
- The existing Slack Socket Mode worker receives the message and binds the Slack user/team to the current DeerFlow user.
|
||||
|
||||
Discord:
|
||||
|
||||
- The frontend creates a short one-time code.
|
||||
- The UI shows `Send /connect <code> to the DeerFlow Discord bot.`
|
||||
- The existing Discord Gateway worker receives the message and binds the Discord user/guild to the current DeerFlow user.
|
||||
|
||||
Codes expire after 10 minutes and are single-use.
|
||||
|
||||
## Runtime Model
|
||||
|
||||
Connection records live in SQL tables under `deerflow.persistence.channel_connections`:
|
||||
|
||||
- `channel_connections`: owner user, provider identity, workspace/guild/team, status, metadata.
|
||||
- `channel_oauth_states`: one-time connect codes and Telegram deep-link state.
|
||||
- `channel_conversations`: connection-scoped IM conversation to DeerFlow thread mapping.
|
||||
- `channel_credentials`: reserved for future provider-token flows, not used by the local/private binding flow.
|
||||
|
||||
Incoming messages that resolve to a connection carry `connection_id`, `owner_user_id`, and `workspace_id`. `ChannelManager` uses `owner_user_id` as the DeerFlow run user id and preserves the raw platform user id as `channel_user_id`.
|
||||
|
||||
## Security Notes
|
||||
|
||||
- Browser APIs remain authenticated and CSRF-protected.
|
||||
- Connect codes are random, short-lived, and single-use.
|
||||
- Provider bot tokens remain in `channels.*` and are never returned to the browser.
|
||||
- This implementation does not add public provider callback or webhook routes.
|
||||
@@ -31,7 +31,8 @@ Current injection format:
|
||||
|
||||
Token counting:
|
||||
- Uses `tiktoken` (`cl100k_base`) when available
|
||||
- Falls back to `len(text) // 4` if tokenizer import fails
|
||||
- Falls back to a network-free CJK-aware character estimate if tokenizer import or encoding load fails
|
||||
(CJK characters count as ~2 chars/token, other characters as ~4 chars/token)
|
||||
|
||||
## Known Gap
|
||||
|
||||
|
||||
@@ -586,7 +586,11 @@ def _get_memory_context(agent_name: str | None = None, *, app_config: AppConfig
|
||||
return ""
|
||||
|
||||
memory_data = get_memory_data(agent_name, user_id=get_effective_user_id())
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||
memory_content = format_memory_for_injection(
|
||||
memory_data,
|
||||
max_tokens=config.max_injection_tokens,
|
||||
use_tiktoken=(config.token_counting == "tiktoken"),
|
||||
)
|
||||
|
||||
if not memory_content.strip():
|
||||
return ""
|
||||
|
||||
@@ -5,7 +5,9 @@ from __future__ import annotations
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
from typing import Any
|
||||
import threading
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -169,7 +171,26 @@ Return ONLY valid JSON."""
|
||||
# subsequent calls are a dict lookup (no network I/O). Pre-warming at
|
||||
# startup via :func:`warm_tiktoken_cache` avoids blocking a request on the
|
||||
# (potentially slow) first ``get_encoding`` call.
|
||||
_tiktoken_encoding_cache: dict[str, tiktoken.Encoding] = {}
|
||||
#
|
||||
# A *failed* load is cached as a ``(None, monotonic_timestamp)`` tuple so that
|
||||
# a network-restricted environment does not re-attempt the blocking BPE
|
||||
# download on every subsequent call. After ``_TIKTOKEN_RETRY_COOLDOWN_S`` the
|
||||
# failure is allowed to expire so a transient network outage can self-heal back
|
||||
# to accurate tiktoken counting without a process restart. A load already in
|
||||
# progress is cached as ``_TIKTOKEN_ENCODING_LOADING`` so concurrent callers
|
||||
# fall back immediately instead of spawning more blocking
|
||||
# ``tiktoken.get_encoding`` threads. Use the ``memory.token_counting: char``
|
||||
# config to skip tiktoken entirely.
|
||||
_TIKTOKEN_ENCODING_MISSING = object()
|
||||
_TIKTOKEN_ENCODING_LOADING = object()
|
||||
# Cooldown before a *failed* tiktoken load is re-attempted. This is an internal
|
||||
# tuning constant rather than a user-facing config: it only affects how quickly
|
||||
# the default ``tiktoken`` mode self-heals after a transient network outage.
|
||||
# Deployments that want to avoid tiktoken's network dependency entirely should
|
||||
# set ``memory.token_counting: char`` instead of tuning this value.
|
||||
_TIKTOKEN_RETRY_COOLDOWN_S = 600.0
|
||||
_tiktoken_encoding_cache: dict[str, Any] = {}
|
||||
_tiktoken_encoding_cache_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encoding | None:
|
||||
@@ -181,44 +202,91 @@ def _get_tiktoken_encoding(encoding_name: str = "cl100k_base") -> tiktoken.Encod
|
||||
download can block for tens of minutes before the OS TCP timeout kicks in.
|
||||
The caller must therefore be prepared for this to block and should run it
|
||||
off the event loop (e.g. via ``asyncio.to_thread``).
|
||||
|
||||
A failed load is remembered (with a timestamp) so subsequent calls fall
|
||||
back immediately to character-based estimation instead of re-triggering the
|
||||
blocking download. The failure expires after ``_TIKTOKEN_RETRY_COOLDOWN_S``
|
||||
so a transient outage can self-heal without a restart. A load already in
|
||||
progress is also remembered so that a timed-out caller does not leave a
|
||||
window where later requests start more blocking ``get_encoding`` calls.
|
||||
"""
|
||||
if not TIKTOKEN_AVAILABLE:
|
||||
return None
|
||||
|
||||
cached = _tiktoken_encoding_cache.get(encoding_name)
|
||||
if cached is not None:
|
||||
return cached
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
cached = _tiktoken_encoding_cache.get(encoding_name, _TIKTOKEN_ENCODING_MISSING)
|
||||
if cached is _TIKTOKEN_ENCODING_LOADING:
|
||||
return None
|
||||
if isinstance(cached, tuple):
|
||||
# Cached failure: (None, failed_at). Retry only after cooldown.
|
||||
_, failed_at = cached
|
||||
if time.monotonic() - failed_at < _TIKTOKEN_RETRY_COOLDOWN_S:
|
||||
return None
|
||||
cached = _TIKTOKEN_ENCODING_MISSING
|
||||
if cached is not _TIKTOKEN_ENCODING_MISSING:
|
||||
return cast("tiktoken.Encoding", cached)
|
||||
_tiktoken_encoding_cache[encoding_name] = _TIKTOKEN_ENCODING_LOADING
|
||||
|
||||
try:
|
||||
encoding = tiktoken.get_encoding(encoding_name)
|
||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||
return encoding
|
||||
except Exception:
|
||||
logger.warning("Failed to load tiktoken encoding %r; falling back to char-based estimation", encoding_name, exc_info=True)
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
_tiktoken_encoding_cache[encoding_name] = (None, time.monotonic())
|
||||
return None
|
||||
|
||||
with _tiktoken_encoding_cache_lock:
|
||||
_tiktoken_encoding_cache[encoding_name] = encoding
|
||||
return encoding
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base") -> int:
|
||||
|
||||
def _char_based_token_estimate(text: str) -> int:
|
||||
"""Network-free token estimate that accounts for CJK density.
|
||||
|
||||
The plain ``len(text) // 4`` heuristic is reasonable for English/code
|
||||
(~4 chars per token) but significantly under-estimates token counts for
|
||||
Chinese, Japanese, and Korean text, where the ratio is closer to 1.5-2
|
||||
characters per token. Counting CJK characters separately (~2 chars per
|
||||
token) avoids over-filling the injection budget for CJK-heavy memory
|
||||
content.
|
||||
"""
|
||||
cjk = sum(
|
||||
1
|
||||
for ch in text
|
||||
if "\u4e00" <= ch <= "\u9fff" # CJK Unified Ideographs
|
||||
or "\u3040" <= ch <= "\u30ff" # Hiragana + Katakana
|
||||
or "\uac00" <= ch <= "\ud7a3" # Hangul syllables
|
||||
)
|
||||
return (len(text) - cjk) // 4 + cjk // 2
|
||||
|
||||
|
||||
def _count_tokens(text: str, encoding_name: str = "cl100k_base", *, use_tiktoken: bool = True) -> int:
|
||||
"""Count tokens in text using tiktoken.
|
||||
|
||||
Args:
|
||||
text: The text to count tokens for.
|
||||
encoding_name: The encoding to use (default: cl100k_base for GPT-4/3.5).
|
||||
use_tiktoken: When ``False``, skip tiktoken entirely and use the
|
||||
network-free character-based estimate. This guarantees no BPE
|
||||
download is attempted (see ``memory.token_counting`` config).
|
||||
|
||||
Returns:
|
||||
The number of tokens in the text.
|
||||
"""
|
||||
if not use_tiktoken:
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
encoding = _get_tiktoken_encoding(encoding_name)
|
||||
if encoding is None:
|
||||
# Fallback to character-based estimation if tiktoken is not available
|
||||
# or the encoding failed to load.
|
||||
return len(text) // 4
|
||||
# Fallback to CJK-aware character estimation if tiktoken is not
|
||||
# available or the encoding failed to load.
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
try:
|
||||
return len(encoding.encode(text))
|
||||
except Exception:
|
||||
# Fallback to character-based estimation on error
|
||||
return len(text) // 4
|
||||
# Fallback to CJK-aware character estimation on error.
|
||||
return _char_based_token_estimate(text)
|
||||
|
||||
|
||||
def warm_tiktoken_cache() -> bool:
|
||||
@@ -248,12 +316,15 @@ def _coerce_confidence(value: Any, default: float = 0.0) -> float:
|
||||
return max(0.0, min(1.0, confidence))
|
||||
|
||||
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000) -> str:
|
||||
def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2000, *, use_tiktoken: bool = True) -> str:
|
||||
"""Format memory data for injection into system prompt.
|
||||
|
||||
Args:
|
||||
memory_data: The memory data dictionary.
|
||||
max_tokens: Maximum tokens to use (counted via tiktoken for accuracy).
|
||||
use_tiktoken: When ``False``, all token counting uses the network-free
|
||||
character-based estimate instead of tiktoken (see
|
||||
``memory.token_counting`` config). Defaults to ``True``.
|
||||
|
||||
Returns:
|
||||
Formatted memory string for system prompt injection.
|
||||
@@ -315,10 +386,10 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
# Compute token count for existing sections once, then account
|
||||
# incrementally for each fact line to avoid full-string re-tokenization.
|
||||
base_text = "\n\n".join(sections)
|
||||
base_tokens = _count_tokens(base_text) if base_text else 0
|
||||
base_tokens = _count_tokens(base_text, use_tiktoken=use_tiktoken) if base_text else 0
|
||||
# Account for the separator between existing sections and the facts section.
|
||||
facts_header = "Facts:\n"
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header) if base_text else _count_tokens(facts_header)
|
||||
separator_tokens = _count_tokens("\n\n" + facts_header, use_tiktoken=use_tiktoken) if base_text else _count_tokens(facts_header, use_tiktoken=use_tiktoken)
|
||||
running_tokens = base_tokens + separator_tokens
|
||||
|
||||
fact_lines: list[str] = []
|
||||
@@ -339,7 +410,7 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
|
||||
# Each additional line is preceded by a newline (except the first).
|
||||
line_text = ("\n" + line) if fact_lines else line
|
||||
line_tokens = _count_tokens(line_text)
|
||||
line_tokens = _count_tokens(line_text, use_tiktoken=use_tiktoken)
|
||||
|
||||
if running_tokens + line_tokens <= max_tokens:
|
||||
fact_lines.append(line)
|
||||
@@ -355,8 +426,9 @@ def format_memory_for_injection(memory_data: dict[str, Any], max_tokens: int = 2
|
||||
|
||||
result = "\n\n".join(sections)
|
||||
|
||||
# Use accurate token counting with tiktoken
|
||||
token_count = _count_tokens(result)
|
||||
# Use accurate token counting with tiktoken (or the char-based estimate
|
||||
# when use_tiktoken is False).
|
||||
token_count = _count_tokens(result, use_tiktoken=use_tiktoken)
|
||||
if token_count > max_tokens:
|
||||
# Truncate to fit within token limit
|
||||
# Estimate characters to remove based on token ratio
|
||||
|
||||
@@ -1141,6 +1141,7 @@ class DeerFlowClient:
|
||||
"fact_confidence_threshold": config.fact_confidence_threshold,
|
||||
"injection_enabled": config.injection_enabled,
|
||||
"max_injection_tokens": config.max_injection_tokens,
|
||||
"token_counting": config.token_counting,
|
||||
}
|
||||
|
||||
def get_memory_status(self) -> dict:
|
||||
|
||||
@@ -67,11 +67,13 @@ def resolve_agent_dir(name: str, *, user_id: str | None = None) -> Path:
|
||||
paths = get_paths()
|
||||
effective_user = user_id or get_effective_user_id()
|
||||
user_path = paths.user_agent_dir(effective_user, name)
|
||||
if user_path.exists():
|
||||
# Require config.yaml to confirm this is a genuine agent directory,
|
||||
# not a leftover from memory/storage writes (see #3390).
|
||||
if user_path.exists() and (user_path / "config.yaml").exists():
|
||||
return user_path
|
||||
|
||||
legacy_path = paths.agent_dir(name)
|
||||
if legacy_path.exists():
|
||||
if legacy_path.exists() and (legacy_path / "config.yaml").exists():
|
||||
return legacy_path
|
||||
|
||||
return user_path
|
||||
|
||||
@@ -11,7 +11,6 @@ from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from deerflow.config.acp_config import ACPAgentConfig, load_acp_config_from_dict
|
||||
from deerflow.config.agents_api_config import AgentsApiConfig, load_agents_api_config_from_dict
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
from deerflow.config.checkpointer_config import CheckpointerConfig, load_checkpointer_config_from_dict
|
||||
from deerflow.config.database_config import DatabaseConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
@@ -117,7 +116,6 @@ class AppConfig(BaseModel):
|
||||
subagents: SubagentsAppConfig = Field(default_factory=SubagentsAppConfig, description="Subagent runtime configuration")
|
||||
guardrails: GuardrailsConfig = Field(default_factory=GuardrailsConfig, description="Guardrail middleware configuration")
|
||||
circuit_breaker: CircuitBreakerConfig = Field(default_factory=CircuitBreakerConfig, description="LLM circuit breaker configuration")
|
||||
channel_connections: ChannelConnectionsConfig = Field(default_factory=ChannelConnectionsConfig, description="User-facing IM channel connection configuration")
|
||||
loop_detection: LoopDetectionConfig = Field(default_factory=LoopDetectionConfig, description="Loop detection middleware configuration")
|
||||
safety_finish_reason: SafetyFinishReasonConfig = Field(default_factory=SafetyFinishReasonConfig, description="Provider safety-filter finish_reason interception middleware configuration")
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
"""Configuration for user-owned IM channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class SlackChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class TelegramChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
bot_username: str = ""
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return bool(self.bot_username)
|
||||
|
||||
|
||||
class DiscordChannelConnectionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
@property
|
||||
def configured(self) -> bool:
|
||||
return True
|
||||
|
||||
|
||||
class ChannelConnectionsConfig(BaseModel):
|
||||
"""Top-level config for browser-connectable IM channels."""
|
||||
|
||||
enabled: bool = False
|
||||
slack: SlackChannelConnectionConfig = Field(default_factory=SlackChannelConnectionConfig)
|
||||
telegram: TelegramChannelConnectionConfig = Field(default_factory=TelegramChannelConnectionConfig)
|
||||
discord: DiscordChannelConnectionConfig = Field(default_factory=DiscordChannelConnectionConfig)
|
||||
|
||||
def provider_status(self, provider: str) -> dict[str, bool]:
|
||||
config = getattr(self, provider, None)
|
||||
if config is None:
|
||||
return {"enabled": False, "configured": False}
|
||||
enabled = bool(config.enabled)
|
||||
return {
|
||||
"enabled": enabled,
|
||||
"configured": enabled and bool(config.configured),
|
||||
}
|
||||
@@ -1,5 +1,7 @@
|
||||
"""Configuration for memory mechanism."""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
@@ -60,6 +62,17 @@ class MemoryConfig(BaseModel):
|
||||
le=8000,
|
||||
description="Maximum tokens to use for memory injection",
|
||||
)
|
||||
token_counting: Literal["tiktoken", "char"] = Field(
|
||||
default="tiktoken",
|
||||
description=(
|
||||
"Token counting strategy for memory-injection budgeting. "
|
||||
"'tiktoken' is accurate but the encoding's BPE data may be "
|
||||
"downloaded from a public network endpoint on first use, which "
|
||||
"can block for a long time in network-restricted environments "
|
||||
"(see issue #3402/#3429). 'char' uses a network-free "
|
||||
"CJK-aware character-based estimate and never touches tiktoken."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Global configuration instance
|
||||
|
||||
@@ -47,7 +47,7 @@ def make_safe_user_id(raw: str) -> str:
|
||||
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
|
||||
if sanitized == raw:
|
||||
return raw
|
||||
digest = hashlib.sha256(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||
return f"{sanitized}-{digest}"
|
||||
|
||||
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
"""User-owned IM channel connection persistence."""
|
||||
|
||||
from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConnectionRow,
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
from deerflow.persistence.channel_connections.sql import (
|
||||
ChannelConnectionRepository,
|
||||
ChannelCredentialCipher,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChannelConnectionRepository",
|
||||
"ChannelConnectionRow",
|
||||
"ChannelConversationRow",
|
||||
"ChannelCredentialCipher",
|
||||
"ChannelCredentialRow",
|
||||
"ChannelOAuthStateRow",
|
||||
]
|
||||
@@ -1,111 +0,0 @@
|
||||
"""ORM models for user-owned IM channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import JSON, DateTime, ForeignKey, Index, Integer, String, Text, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class ChannelConnectionRow(Base):
|
||||
__tablename__ = "channel_connections"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||
status: Mapped[str] = mapped_column(String(32), nullable=False, default="connected")
|
||||
|
||||
external_account_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
external_account_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
workspace_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
workspace_name: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
bot_user_id: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
|
||||
scopes_json: Mapped[list] = mapped_column(JSON, default=list)
|
||||
capabilities_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||
last_seen_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
last_error_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"owner_user_id",
|
||||
"provider",
|
||||
"external_account_id",
|
||||
"workspace_id",
|
||||
name="uq_channel_connection_owner_provider_identity",
|
||||
),
|
||||
Index("idx_channel_connections_event_lookup", "provider", "workspace_id", "bot_user_id"),
|
||||
)
|
||||
|
||||
|
||||
class ChannelCredentialRow(Base):
|
||||
__tablename__ = "channel_credentials"
|
||||
|
||||
connection_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("channel_connections.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
encrypted_access_token: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
encrypted_refresh_token: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
token_type: Mapped[str | None] = mapped_column(String(32), nullable=True)
|
||||
expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
refresh_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
encrypted_extra_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||
|
||||
|
||||
class ChannelOAuthStateRow(Base):
|
||||
__tablename__ = "channel_oauth_states"
|
||||
|
||||
state_hash: Mapped[str] = mapped_column(String(128), primary_key=True)
|
||||
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||
code_verifier_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
nonce_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)
|
||||
redirect_after: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
requested_scopes_json: Mapped[list] = mapped_column(JSON, default=list)
|
||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
consumed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||
|
||||
|
||||
class ChannelConversationRow(Base):
|
||||
__tablename__ = "channel_conversations"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
connection_id: Mapped[str] = mapped_column(
|
||||
String(64),
|
||||
ForeignKey("channel_connections.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
owner_user_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
provider: Mapped[str] = mapped_column(String(32), nullable=False, index=True)
|
||||
external_conversation_id: Mapped[str] = mapped_column(String(128), nullable=False)
|
||||
external_topic_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, default=_utc_now, onupdate=_utc_now)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"connection_id",
|
||||
"external_conversation_id",
|
||||
"external_topic_id",
|
||||
name="uq_channel_conversation_connection_external",
|
||||
),
|
||||
)
|
||||
@@ -1,346 +0,0 @@
|
||||
"""SQL repository for user-owned IM channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConnectionRow,
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
from deerflow.utils.time import coerce_iso
|
||||
|
||||
|
||||
class ChannelCredentialCipher:
|
||||
"""Encrypts provider credentials before they are persisted."""
|
||||
|
||||
def __init__(self, fernet: Fernet) -> None:
|
||||
self._fernet = fernet
|
||||
|
||||
@classmethod
|
||||
def from_key(cls, key: str) -> ChannelCredentialCipher:
|
||||
digest = hashlib.sha256(key.encode("utf-8")).digest()
|
||||
return cls(Fernet(base64.urlsafe_b64encode(digest)))
|
||||
|
||||
def encrypt_text(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
return "fernet:v1:" + self._fernet.encrypt(value.encode("utf-8")).decode("ascii")
|
||||
|
||||
def decrypt_text(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
token = value.removeprefix("fernet:v1:")
|
||||
return self._fernet.decrypt(token.encode("ascii")).decode("utf-8")
|
||||
|
||||
|
||||
class ChannelConnectionRepository:
|
||||
"""Persistence facade for channel connections, credentials, and conversations."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
*,
|
||||
cipher: ChannelCredentialCipher | None = None,
|
||||
) -> None:
|
||||
self.session_factory = session_factory
|
||||
self._cipher = cipher
|
||||
|
||||
async def close(self) -> None:
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
await close_engine()
|
||||
|
||||
@staticmethod
|
||||
def _new_id() -> str:
|
||||
return uuid.uuid4().hex
|
||||
|
||||
@staticmethod
|
||||
def _normalize_optional_identity(value: str | None) -> str:
|
||||
return value or ""
|
||||
|
||||
@staticmethod
|
||||
def _coerce_datetime(value: datetime | None) -> datetime | None:
|
||||
if value is None or value.tzinfo is not None:
|
||||
return value
|
||||
return value.replace(tzinfo=UTC)
|
||||
|
||||
def _encrypt_optional_secret(self, value: str | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if self._cipher is None:
|
||||
raise RuntimeError("channel connection encryption key is required")
|
||||
return self._cipher.encrypt_text(value)
|
||||
|
||||
@staticmethod
|
||||
def _connection_to_dict(row: ChannelConnectionRow) -> dict[str, Any]:
|
||||
data = row.to_dict()
|
||||
data["external_account_id"] = data["external_account_id"] or None
|
||||
data["workspace_id"] = data["workspace_id"] or None
|
||||
data["scopes"] = data.pop("scopes_json") or []
|
||||
data["capabilities"] = data.pop("capabilities_json") or {}
|
||||
data["metadata"] = data.pop("metadata_json") or {}
|
||||
for key in ("created_at", "updated_at", "last_seen_at", "last_error_at"):
|
||||
value = data.get(key)
|
||||
if isinstance(value, datetime):
|
||||
data[key] = coerce_iso(value)
|
||||
return data
|
||||
|
||||
async def upsert_connection(
|
||||
self,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
external_account_id: str | None = None,
|
||||
external_account_name: str | None = None,
|
||||
workspace_id: str | None = None,
|
||||
workspace_name: str | None = None,
|
||||
bot_user_id: str | None = None,
|
||||
scopes: list[str] | None = None,
|
||||
capabilities: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str = "connected",
|
||||
) -> dict[str, Any]:
|
||||
external_account_id_value = self._normalize_optional_identity(external_account_id)
|
||||
workspace_id_value = self._normalize_optional_identity(workspace_id)
|
||||
async with self.session_factory() as session:
|
||||
stmt = select(ChannelConnectionRow).where(
|
||||
ChannelConnectionRow.owner_user_id == owner_user_id,
|
||||
ChannelConnectionRow.provider == provider,
|
||||
ChannelConnectionRow.external_account_id == external_account_id_value,
|
||||
ChannelConnectionRow.workspace_id == workspace_id_value,
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
row = ChannelConnectionRow(
|
||||
id=self._new_id(),
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
external_account_id=external_account_id_value,
|
||||
workspace_id=workspace_id_value,
|
||||
)
|
||||
session.add(row)
|
||||
|
||||
row.status = status
|
||||
row.external_account_name = external_account_name
|
||||
row.workspace_name = workspace_name
|
||||
row.bot_user_id = bot_user_id
|
||||
row.scopes_json = list(scopes or [])
|
||||
row.capabilities_json = dict(capabilities or {})
|
||||
row.metadata_json = dict(metadata or {})
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._connection_to_dict(row)
|
||||
|
||||
async def list_connections(self, owner_user_id: str) -> list[dict[str, Any]]:
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(select(ChannelConnectionRow).where(ChannelConnectionRow.owner_user_id == owner_user_id).order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc()))
|
||||
return [self._connection_to_dict(row) for row in result.scalars()]
|
||||
|
||||
async def disconnect_connection(self, *, connection_id: str, owner_user_id: str) -> bool:
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelConnectionRow, connection_id)
|
||||
if row is None or row.owner_user_id != owner_user_id:
|
||||
return False
|
||||
|
||||
row.status = "revoked"
|
||||
credential = await session.get(ChannelCredentialRow, connection_id)
|
||||
if credential is not None:
|
||||
await session.delete(credential)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def store_credentials(
|
||||
self,
|
||||
connection_id: str,
|
||||
*,
|
||||
access_token: str | None,
|
||||
refresh_token: str | None = None,
|
||||
token_type: str | None = None,
|
||||
expires_at: datetime | None = None,
|
||||
refresh_expires_at: datetime | None = None,
|
||||
extra: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
if self._cipher is None:
|
||||
raise RuntimeError("channel connection encryption key is required")
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelCredentialRow, connection_id)
|
||||
if row is None:
|
||||
row = ChannelCredentialRow(connection_id=connection_id)
|
||||
session.add(row)
|
||||
row.encrypted_access_token = self._cipher.encrypt_text(access_token)
|
||||
row.encrypted_refresh_token = self._cipher.encrypt_text(refresh_token)
|
||||
row.token_type = token_type
|
||||
row.expires_at = expires_at
|
||||
row.refresh_expires_at = refresh_expires_at
|
||||
row.encrypted_extra_json = self._cipher.encrypt_text(json.dumps(extra or {}, ensure_ascii=False))
|
||||
row.version = (row.version or 0) + 1
|
||||
await session.commit()
|
||||
|
||||
async def get_credentials(self, connection_id: str) -> dict[str, Any] | None:
|
||||
if self._cipher is None:
|
||||
return None
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelCredentialRow, connection_id)
|
||||
if row is None:
|
||||
return None
|
||||
extra_raw = self._cipher.decrypt_text(row.encrypted_extra_json)
|
||||
return {
|
||||
"connection_id": row.connection_id,
|
||||
"access_token": self._cipher.decrypt_text(row.encrypted_access_token),
|
||||
"refresh_token": self._cipher.decrypt_text(row.encrypted_refresh_token),
|
||||
"token_type": row.token_type,
|
||||
"expires_at": self._coerce_datetime(row.expires_at),
|
||||
"refresh_expires_at": self._coerce_datetime(row.refresh_expires_at),
|
||||
"extra": json.loads(extra_raw) if extra_raw else {},
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def hash_state(state: str) -> str:
|
||||
return hashlib.sha256(state.encode("utf-8")).hexdigest()
|
||||
|
||||
async def create_oauth_state(
|
||||
self,
|
||||
*,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
state: str,
|
||||
expires_at: datetime,
|
||||
code_verifier: str | None = None,
|
||||
nonce_hash: str | None = None,
|
||||
redirect_after: str | None = None,
|
||||
requested_scopes: list[str] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
row = ChannelOAuthStateRow(
|
||||
state_hash=self.hash_state(state),
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
|
||||
nonce_hash=nonce_hash,
|
||||
redirect_after=redirect_after,
|
||||
requested_scopes_json=list(requested_scopes or []),
|
||||
metadata_json=dict(metadata or {}),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
async with self.session_factory() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int:
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(ChannelOAuthStateRow).where(
|
||||
ChannelOAuthStateRow.owner_user_id == owner_user_id,
|
||||
ChannelOAuthStateRow.provider == provider,
|
||||
)
|
||||
)
|
||||
return len(list(result.scalars()))
|
||||
|
||||
async def consume_oauth_state(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
state: str,
|
||||
now: datetime | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
current_time = now or datetime.now(UTC)
|
||||
async with self.session_factory() as session:
|
||||
row = await session.get(ChannelOAuthStateRow, self.hash_state(state))
|
||||
if row is None or row.provider != provider or row.consumed_at is not None:
|
||||
return None
|
||||
expires_at = self._coerce_datetime(row.expires_at)
|
||||
if expires_at is not None and expires_at < current_time:
|
||||
return None
|
||||
|
||||
row.consumed_at = current_time
|
||||
await session.commit()
|
||||
return {
|
||||
"owner_user_id": row.owner_user_id,
|
||||
"provider": row.provider,
|
||||
"requested_scopes": row.requested_scopes_json or [],
|
||||
"metadata": row.metadata_json or {},
|
||||
"redirect_after": row.redirect_after,
|
||||
}
|
||||
|
||||
async def find_connection_by_external_identity(
|
||||
self,
|
||||
*,
|
||||
provider: str,
|
||||
external_account_id: str,
|
||||
workspace_id: str | None = None,
|
||||
) -> dict[str, Any] | None:
|
||||
async with self.session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(ChannelConnectionRow)
|
||||
.where(
|
||||
ChannelConnectionRow.provider == provider,
|
||||
ChannelConnectionRow.external_account_id == self._normalize_optional_identity(external_account_id),
|
||||
ChannelConnectionRow.workspace_id == self._normalize_optional_identity(workspace_id),
|
||||
ChannelConnectionRow.status == "connected",
|
||||
)
|
||||
.order_by(ChannelConnectionRow.updated_at.desc(), ChannelConnectionRow.id.desc())
|
||||
.limit(1)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._connection_to_dict(row) if row is not None else None
|
||||
|
||||
async def set_thread_id(
|
||||
self,
|
||||
*,
|
||||
connection_id: str,
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
external_conversation_id: str,
|
||||
thread_id: str,
|
||||
external_topic_id: str | None = None,
|
||||
) -> None:
|
||||
topic_id = external_topic_id or ""
|
||||
async with self.session_factory() as session:
|
||||
stmt = select(ChannelConversationRow).where(
|
||||
ChannelConversationRow.connection_id == connection_id,
|
||||
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
||||
ChannelConversationRow.external_topic_id == topic_id,
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
row = ChannelConversationRow(
|
||||
id=self._new_id(),
|
||||
connection_id=connection_id,
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
external_conversation_id=external_conversation_id,
|
||||
external_topic_id=topic_id,
|
||||
thread_id=thread_id,
|
||||
)
|
||||
session.add(row)
|
||||
else:
|
||||
row.thread_id = thread_id
|
||||
row.owner_user_id = owner_user_id
|
||||
row.provider = provider
|
||||
await session.commit()
|
||||
|
||||
async def get_thread_id(
|
||||
self,
|
||||
connection_id: str,
|
||||
external_conversation_id: str,
|
||||
external_topic_id: str | None = None,
|
||||
) -> str | None:
|
||||
async with self.session_factory() as session:
|
||||
stmt = select(ChannelConversationRow.thread_id).where(
|
||||
ChannelConversationRow.connection_id == connection_id,
|
||||
ChannelConversationRow.external_conversation_id == external_conversation_id,
|
||||
ChannelConversationRow.external_topic_id == (external_topic_id or ""),
|
||||
)
|
||||
return (await session.execute(stmt)).scalar_one_or_none()
|
||||
@@ -14,26 +14,10 @@ its storage implementation lives in ``deerflow.runtime.events.store.db`` and
|
||||
there is no matching entity directory.
|
||||
"""
|
||||
|
||||
from deerflow.persistence.channel_connections.model import (
|
||||
ChannelConnectionRow,
|
||||
ChannelConversationRow,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
from deerflow.persistence.feedback.model import FeedbackRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.persistence.run.model import RunRow
|
||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
__all__ = [
|
||||
"ChannelConnectionRow",
|
||||
"ChannelConversationRow",
|
||||
"ChannelCredentialRow",
|
||||
"ChannelOAuthStateRow",
|
||||
"FeedbackRow",
|
||||
"RunEventRow",
|
||||
"RunRow",
|
||||
"ThreadMetaRow",
|
||||
"UserRow",
|
||||
]
|
||||
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
|
||||
|
||||
@@ -36,7 +36,6 @@ dependencies = [
|
||||
"sqlalchemy[asyncio]>=2.0,<3.0",
|
||||
"aiosqlite>=0.19",
|
||||
"alembic>=1.13",
|
||||
"cryptography>=43.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
||||
@@ -39,8 +39,6 @@ def test_public_paths(path: str):
|
||||
"/api/threads/123/uploads",
|
||||
"/api/agents",
|
||||
"/api/channels",
|
||||
"/api/channels/providers",
|
||||
"/api/channels/slack/connect",
|
||||
"/api/runs/stream",
|
||||
"/api/threads/123/runs",
|
||||
"/api/v1/auth/me",
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""Tests for user-facing IM channel connection configuration."""
|
||||
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
|
||||
def test_channel_connections_disabled_by_default():
|
||||
config = ChannelConnectionsConfig()
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.slack.enabled is False
|
||||
assert config.telegram.enabled is False
|
||||
assert config.discord.enabled is False
|
||||
|
||||
|
||||
def test_enabled_channel_connections_do_not_require_public_url_or_encryption_key():
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {
|
||||
"enabled": True,
|
||||
"bot_username": "deerflow_bot",
|
||||
},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.provider_status("telegram") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("slack") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("discord") == {"enabled": True, "configured": True}
|
||||
|
||||
|
||||
def test_provider_status_reports_disabled_and_unknown_providers():
|
||||
config = ChannelConnectionsConfig.model_validate({"enabled": True})
|
||||
|
||||
assert config.provider_status("slack") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("telegram") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("discord") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("unknown") == {"enabled": False, "configured": False}
|
||||
@@ -1,202 +0,0 @@
|
||||
"""Tests for per-user IM channel connection persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from deerflow.persistence.channel_connections import (
|
||||
ChannelConnectionRepository,
|
||||
ChannelConnectionRow,
|
||||
ChannelCredentialCipher,
|
||||
ChannelCredentialRow,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'channels.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
try:
|
||||
yield ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("test-encryption-key"),
|
||||
)
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
class TestChannelConnectionRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_connections_are_listed_per_owner(self, repo):
|
||||
alice = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
external_account_name="Alice",
|
||||
workspace_id="T1",
|
||||
workspace_name="Team One",
|
||||
scopes=["chat:write"],
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_account_id="U-bob",
|
||||
external_account_name="Bob",
|
||||
workspace_id="T1",
|
||||
workspace_name="Team One",
|
||||
scopes=["chat:write"],
|
||||
)
|
||||
|
||||
results = await repo.list_connections("alice")
|
||||
|
||||
assert [item["id"] for item in results] == [alice["id"]]
|
||||
assert results[0]["owner_user_id"] == "alice"
|
||||
assert results[0]["provider"] == "slack"
|
||||
assert results[0]["scopes"] == ["chat:write"]
|
||||
assert "encrypted_access_token" not in results[0]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_connection_updates_existing_provider_identity(self, repo):
|
||||
first = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice",
|
||||
workspace_id=None,
|
||||
workspace_name=None,
|
||||
status="pending",
|
||||
)
|
||||
second = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice Telegram",
|
||||
workspace_id=None,
|
||||
workspace_name=None,
|
||||
status="connected",
|
||||
)
|
||||
|
||||
assert second["id"] == first["id"]
|
||||
assert second["status"] == "connected"
|
||||
assert second["external_account_name"] == "Alice Telegram"
|
||||
assert len(await repo.list_connections("alice")) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
workspace_id="T1",
|
||||
)
|
||||
expires_at = datetime.now(UTC) + timedelta(hours=1)
|
||||
|
||||
await repo.store_credentials(
|
||||
connection["id"],
|
||||
access_token="xoxb-secret-access-token",
|
||||
refresh_token="secret-refresh-token",
|
||||
token_type="Bearer",
|
||||
expires_at=expires_at,
|
||||
extra={"bot_user_id": "B123"},
|
||||
)
|
||||
|
||||
async with repo.session_factory() as session:
|
||||
row = (await session.execute(select(ChannelCredentialRow))).scalar_one()
|
||||
assert row.encrypted_access_token is not None
|
||||
assert "xoxb-secret-access-token" not in row.encrypted_access_token
|
||||
assert "secret-refresh-token" not in (row.encrypted_refresh_token or "")
|
||||
assert "B123" not in (row.encrypted_extra_json or "")
|
||||
|
||||
credentials = await repo.get_credentials(connection["id"])
|
||||
|
||||
assert credentials is not None
|
||||
assert credentials["access_token"] == "xoxb-secret-access-token"
|
||||
assert credentials["refresh_token"] == "secret-refresh-token"
|
||||
assert credentials["token_type"] == "Bearer"
|
||||
assert credentials["expires_at"] == expires_at
|
||||
assert credentials["extra"] == {"bot_user_id": "B123"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_conversations_are_scoped_by_connection(self, repo):
|
||||
alice = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
workspace_id="T1",
|
||||
)
|
||||
bob = await repo.upsert_connection(
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_account_id="U-bob",
|
||||
workspace_id="T1",
|
||||
)
|
||||
|
||||
await repo.set_thread_id(
|
||||
connection_id=alice["id"],
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_conversation_id="C-shared",
|
||||
external_topic_id="1710000000.000100",
|
||||
thread_id="thread-alice",
|
||||
)
|
||||
await repo.set_thread_id(
|
||||
connection_id=bob["id"],
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_conversation_id="C-shared",
|
||||
external_topic_id="1710000000.000100",
|
||||
thread_id="thread-bob",
|
||||
)
|
||||
|
||||
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
|
||||
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
)
|
||||
await repo.store_credentials(connection["id"], access_token="secret-token")
|
||||
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection["id"],
|
||||
owner_user_id="alice",
|
||||
)
|
||||
|
||||
assert disconnected is True
|
||||
async with repo.session_factory() as session:
|
||||
connection_row = await session.get(ChannelConnectionRow, connection["id"])
|
||||
credential_row = await session.get(ChannelCredentialRow, connection["id"])
|
||||
assert connection_row is not None
|
||||
assert connection_row.status == "revoked"
|
||||
assert credential_row is None
|
||||
assert (
|
||||
await repo.find_connection_by_external_identity(
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_disconnect_connection_is_owner_scoped(self, repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
)
|
||||
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection["id"],
|
||||
owner_user_id="bob",
|
||||
)
|
||||
|
||||
assert disconnected is False
|
||||
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
|
||||
@@ -1,322 +0,0 @@
|
||||
"""Router tests for browser-connectable IM channels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.routers import channel_connections
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
|
||||
def _user() -> User:
|
||||
return User(
|
||||
id=UUID("11111111-2222-3333-4444-555555555555"),
|
||||
email="alice@example.com",
|
||||
password_hash="x",
|
||||
system_role="user",
|
||||
)
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'router.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(get_session_factory())
|
||||
|
||||
|
||||
def _make_app(config: ChannelConnectionsConfig, repo, channels_config: dict | None = None):
|
||||
app = make_authed_test_app(user_factory=_user)
|
||||
app.state.channel_connections_config = config
|
||||
app.state.channel_connection_repo = repo
|
||||
app.state.channels_config = channels_config or {}
|
||||
app.include_router(channel_connections.router)
|
||||
return app
|
||||
|
||||
|
||||
def _enabled_connections_config() -> ChannelConnectionsConfig:
|
||||
return ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _channels_config() -> dict:
|
||||
return {
|
||||
"telegram": {"enabled": True, "bot_token": "telegram-token"},
|
||||
"slack": {"enabled": True, "bot_token": "xoxb-operator", "app_token": "xapp-operator"},
|
||||
"discord": {"enabled": True, "bot_token": "discord-bot"},
|
||||
}
|
||||
|
||||
|
||||
def test_get_providers_uses_existing_channels_config(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["enabled"] is True
|
||||
by_provider = {item["provider"]: item for item in body["providers"]}
|
||||
assert by_provider["telegram"]["configured"] is True
|
||||
assert by_provider["telegram"]["auth_mode"] == "deep_link"
|
||||
assert by_provider["slack"]["configured"] is True
|
||||
assert by_provider["slack"]["auth_mode"] == "binding_code"
|
||||
assert by_provider["discord"]["configured"] is True
|
||||
assert by_provider["discord"]["auth_mode"] == "binding_code"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_reports_unconfigured_when_runtime_channel_is_missing(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, {"telegram": {"enabled": True, "bot_token": "telegram-token"}})
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["telegram"]["configured"] is True
|
||||
assert by_provider["slack"]["configured"] is False
|
||||
assert by_provider["slack"]["connectable"] is False
|
||||
assert "channels.slack" in by_provider["slack"]["unavailable_reason"]
|
||||
assert by_provider["discord"]["configured"] is False
|
||||
assert "channels.discord" in by_provider["discord"]["unavailable_reason"]
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connections():
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="slack",
|
||||
external_account_id="U-old",
|
||||
workspace_id="T-old",
|
||||
status="revoked",
|
||||
)
|
||||
await anyio.sleep(0.01)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="slack",
|
||||
external_account_id="U-new",
|
||||
workspace_id="T-new",
|
||||
status="connected",
|
||||
)
|
||||
|
||||
anyio.run(seed_connections)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["slack"]["connection_status"] == "connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_connections_returns_current_user_connections_only(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connections():
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice",
|
||||
status="connected",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="other-user",
|
||||
provider="telegram",
|
||||
external_account_id="99",
|
||||
external_account_name="Bob",
|
||||
status="connected",
|
||||
)
|
||||
|
||||
anyio.run(seed_connections)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/connections")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert len(body["connections"]) == 1
|
||||
assert body["connections"][0]["provider"] == "telegram"
|
||||
assert body["connections"][0]["external_account_id"] == "42"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_telegram_returns_deep_link_and_persists_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/telegram/connect")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["provider"] == "telegram"
|
||||
assert body["mode"] == "deep_link"
|
||||
assert body["url"].startswith("https://t.me/deerflow_bot?start=")
|
||||
assert body["code"]
|
||||
assert "/start" in body["instruction"]
|
||||
|
||||
async def count_states():
|
||||
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="telegram")
|
||||
|
||||
assert anyio.run(count_states) == 1
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/slack/connect")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["provider"] == "slack"
|
||||
assert body["mode"] == "binding_code"
|
||||
assert body["url"] is None
|
||||
assert len(body["code"]) >= 22
|
||||
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot."
|
||||
|
||||
async def count_states():
|
||||
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack")
|
||||
|
||||
assert anyio.run(count_states) == 1
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_discord_returns_binding_command_and_persists_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/discord/connect")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["provider"] == "discord"
|
||||
assert body["mode"] == "binding_code"
|
||||
assert body["url"] is None
|
||||
assert body["code"]
|
||||
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Discord bot."
|
||||
|
||||
async def count_states():
|
||||
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="discord")
|
||||
|
||||
assert anyio.run(count_states) == 1
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_unconfigured_runtime_channel_returns_400(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, {})
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/slack/connect")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "channels.slack" in response.json()["detail"]
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connection():
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
status="connected",
|
||||
)
|
||||
return connection["id"]
|
||||
|
||||
connection_id = anyio.run(seed_connection)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.delete(f"/api/channels/connections/{connection_id}")
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
async def get_connection_status():
|
||||
return (await repo.list_connections(str(_user().id)))[0]["status"]
|
||||
|
||||
assert anyio.run(get_connection_status) == "revoked"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_connection_is_current_user_scoped(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connection():
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="other-user",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
status="connected",
|
||||
)
|
||||
return connection["id"]
|
||||
|
||||
connection_id = anyio.run(seed_connection)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.delete(f"/api/channels/connections/{connection_id}")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
async def get_connection_status():
|
||||
return (await repo.list_connections("other-user"))[0]["status"]
|
||||
|
||||
assert anyio.run(get_connection_status) == "connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
@@ -504,17 +504,6 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
|
||||
return mock_client
|
||||
|
||||
|
||||
async def _make_channel_connection_repo(tmp_path: Path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'channel-connections.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("test-channel-key"),
|
||||
)
|
||||
|
||||
|
||||
def _make_stream_part(event: str, data):
|
||||
return SimpleNamespace(event=event, data=data)
|
||||
|
||||
@@ -2345,22 +2334,6 @@ class TestResolveRunParamsUserId:
|
||||
assert run_context["user_id"] == "123456"
|
||||
assert run_context["channel_user_id"] == "123456"
|
||||
|
||||
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self):
|
||||
manager = self._manager()
|
||||
msg = InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C123",
|
||||
user_id="U-platform",
|
||||
owner_user_id="deerflow-user-1",
|
||||
connection_id="connection-1",
|
||||
text="hi",
|
||||
)
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == "deerflow-user-1"
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
|
||||
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
|
||||
from deerflow.config.paths import make_safe_user_id
|
||||
|
||||
@@ -2385,80 +2358,6 @@ class TestResolveRunParamsUserId:
|
||||
assert "channel_user_id" not in run_context
|
||||
|
||||
|
||||
class TestChannelManagerConnectionRouting:
|
||||
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path):
|
||||
from app.channels.manager import ChannelManager
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
async def go():
|
||||
repo = await _make_channel_connection_repo(tmp_path)
|
||||
alice = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
workspace_id="T1",
|
||||
)
|
||||
bob = await repo.upsert_connection(
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_account_id="U-bob",
|
||||
workspace_id="T1",
|
||||
)
|
||||
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=tmp_path / "legacy-store.json")
|
||||
manager = ChannelManager(bus=bus, store=store, connection_repo=repo)
|
||||
mock_client = _make_mock_langgraph_client()
|
||||
mock_client.threads.create = AsyncMock(
|
||||
side_effect=[
|
||||
{"thread_id": "thread-alice"},
|
||||
{"thread_id": "thread-bob"},
|
||||
]
|
||||
)
|
||||
manager._client = mock_client
|
||||
|
||||
await manager._handle_chat(
|
||||
InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C-shared",
|
||||
user_id="U-alice",
|
||||
owner_user_id="alice",
|
||||
connection_id=alice["id"],
|
||||
text="hello",
|
||||
thread_ts="1710000000.000100",
|
||||
topic_id="1710000000.000100",
|
||||
)
|
||||
)
|
||||
await manager._handle_chat(
|
||||
InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C-shared",
|
||||
user_id="U-bob",
|
||||
owner_user_id="bob",
|
||||
connection_id=bob["id"],
|
||||
text="hello",
|
||||
thread_ts="1710000000.000100",
|
||||
topic_id="1710000000.000100",
|
||||
)
|
||||
)
|
||||
|
||||
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
|
||||
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
|
||||
assert store.list_entries() == []
|
||||
|
||||
first_context = mock_client.runs.wait.call_args_list[0].kwargs["context"]
|
||||
second_context = mock_client.runs.wait.call_args_list[1].kwargs["context"]
|
||||
assert first_context["user_id"] == "alice"
|
||||
assert first_context["channel_user_id"] == "U-alice"
|
||||
assert second_context["user_id"] == "bob"
|
||||
assert second_context["channel_user_id"] == "U-bob"
|
||||
|
||||
try:
|
||||
_run(go())
|
||||
finally:
|
||||
_run(close_engine())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelService tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -3276,62 +3175,6 @@ class TestChannelService:
|
||||
|
||||
assert service._config == {"telegram": {"enabled": False}}
|
||||
|
||||
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(self):
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
app_config = SimpleNamespace(
|
||||
model_extra={},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
service = ChannelService.from_app_config(app_config)
|
||||
|
||||
assert service._config == {}
|
||||
|
||||
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(self):
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
app_config = SimpleNamespace(
|
||||
model_extra={
|
||||
"channels": {
|
||||
"telegram": {"enabled": True, "bot_token": "telegram-token"},
|
||||
"slack": {"enabled": True, "bot_token": "xoxb", "app_token": "xapp"},
|
||||
"discord": {"enabled": True, "bot_token": "discord-bot-token"},
|
||||
}
|
||||
},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
service = ChannelService.from_app_config(app_config)
|
||||
|
||||
assert service._config["telegram"]["bot_token"] == "telegram-token"
|
||||
assert service._config["slack"]["app_token"] == "xapp"
|
||||
assert service._config["discord"]["bot_token"] == "discord-bot-token"
|
||||
|
||||
def test_connection_repo_is_forwarded_to_manager(self):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
repo = object()
|
||||
service = ChannelService(channels_config={}, connection_repo=repo)
|
||||
|
||||
assert service.manager._connection_repo is repo
|
||||
|
||||
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
|
||||
"""Warning is emitted when a channel has string credentials but enabled=false."""
|
||||
import logging
|
||||
|
||||
@@ -2472,6 +2472,7 @@ class TestGatewayConformance:
|
||||
mem_cfg.fact_confidence_threshold = 0.7
|
||||
mem_cfg.injection_enabled = True
|
||||
mem_cfg.max_injection_tokens = 2000
|
||||
mem_cfg.token_counting = "tiktoken"
|
||||
|
||||
with patch("deerflow.config.memory_config.get_memory_config", return_value=mem_cfg):
|
||||
result = client.get_memory_config()
|
||||
@@ -2479,6 +2480,7 @@ class TestGatewayConformance:
|
||||
parsed = MemoryConfigResponse(**result)
|
||||
assert parsed.enabled is True
|
||||
assert parsed.max_facts == 100
|
||||
assert parsed.token_counting == "tiktoken"
|
||||
|
||||
def test_get_memory_status(self, client):
|
||||
mem_cfg = MagicMock()
|
||||
@@ -2489,6 +2491,7 @@ class TestGatewayConformance:
|
||||
mem_cfg.fact_confidence_threshold = 0.7
|
||||
mem_cfg.injection_enabled = True
|
||||
mem_cfg.max_injection_tokens = 2000
|
||||
mem_cfg.token_counting = "tiktoken"
|
||||
|
||||
memory_data = {
|
||||
"version": "1.0",
|
||||
@@ -2514,6 +2517,7 @@ class TestGatewayConformance:
|
||||
|
||||
parsed = MemoryStatusResponse(**result)
|
||||
assert parsed.config.enabled is True
|
||||
assert parsed.config.token_counting == "tiktoken"
|
||||
assert parsed.data.version == "1.0"
|
||||
|
||||
|
||||
|
||||
@@ -233,15 +233,3 @@ def test_non_auth_mutation_rejects_mismatched_double_submit_token():
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token mismatch."
|
||||
|
||||
|
||||
def test_channel_posts_require_double_submit_csrf():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/channels/slack/connect",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
|
||||
|
||||
@@ -203,6 +203,79 @@ class TestLoadAgentConfig:
|
||||
assert cfg.name == "legacy-agent"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 3b. resolve_agent_dir — memory-only directory fallback (#3390)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestResolveAgentDirMemoryOnlyFallback:
|
||||
"""Regression tests for #3390.
|
||||
|
||||
When memory is enabled, the first conversation creates a user-isolated
|
||||
agent directory containing only ``memory.json`` (no ``config.yaml``).
|
||||
On the next turn ``resolve_agent_dir`` must fall through to the legacy
|
||||
shared layout instead of returning the incomplete user directory.
|
||||
"""
|
||||
|
||||
def test_user_dir_with_only_memory_falls_back_to_legacy(self, tmp_path):
|
||||
"""User dir has memory.json but no config.yaml → use legacy dir."""
|
||||
from deerflow.config.agents_config import resolve_agent_dir
|
||||
|
||||
# Legacy agent with full config
|
||||
legacy_dir = tmp_path / "agents" / "my-agent"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
|
||||
(legacy_dir / "SOUL.md").write_text("legacy soul", encoding="utf-8")
|
||||
|
||||
# User dir created by memory write — no config.yaml
|
||||
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||
user_dir.mkdir(parents=True)
|
||||
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||
result = resolve_agent_dir("my-agent", user_id="u1")
|
||||
|
||||
assert result == legacy_dir
|
||||
|
||||
def test_user_dir_with_config_takes_priority(self, tmp_path):
|
||||
"""User dir with config.yaml should still win over legacy."""
|
||||
from deerflow.config.agents_config import resolve_agent_dir
|
||||
|
||||
# Legacy
|
||||
legacy_dir = tmp_path / "agents" / "my-agent"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
(legacy_dir / "config.yaml").write_text("name: my-agent\n", encoding="utf-8")
|
||||
|
||||
# User dir with full config (migrated)
|
||||
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||
user_dir.mkdir(parents=True)
|
||||
(user_dir / "config.yaml").write_text("name: my-agent\nmodel: gpt-4\n", encoding="utf-8")
|
||||
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||
result = resolve_agent_dir("my-agent", user_id="u1")
|
||||
|
||||
assert result == user_dir
|
||||
|
||||
def test_load_config_falls_back_when_user_dir_is_memory_only(self, tmp_path):
|
||||
"""End-to-end: load_agent_config works when user dir only has memory.json."""
|
||||
config_dict = {"name": "my-agent", "description": "Legacy agent", "model": "deepseek-v3"}
|
||||
_write_agent(tmp_path, "my-agent", config_dict)
|
||||
|
||||
# Simulate memory write creating user dir without config
|
||||
user_dir = tmp_path / "users" / "u1" / "agents" / "my-agent"
|
||||
user_dir.mkdir(parents=True)
|
||||
(user_dir / "memory.json").write_text("{}", encoding="utf-8")
|
||||
|
||||
with patch("deerflow.config.agents_config.get_paths", return_value=_make_paths(tmp_path)), patch("deerflow.config.agents_config.get_effective_user_id", return_value="u1"):
|
||||
from deerflow.config.agents_config import load_agent_config
|
||||
|
||||
cfg = load_agent_config("my-agent", user_id="u1")
|
||||
|
||||
assert cfg.name == "my-agent"
|
||||
assert cfg.model == "deepseek-v3"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# 4. load_agent_soul
|
||||
# ===========================================================================
|
||||
|
||||
@@ -1,88 +0,0 @@
|
||||
"""Discord connection routing tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.channels.discord import DiscordChannel
|
||||
from app.channels.message_bus import InboundMessage, MessageBus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'discord.db'}", sqlite_dir=str(tmp_path))
|
||||
try:
|
||||
yield ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("discord-secret"),
|
||||
)
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_discord_inbound_attaches_owner_identity_from_user_level_connection(repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="discord",
|
||||
external_account_id="987",
|
||||
external_account_name="Alice",
|
||||
status="connected",
|
||||
)
|
||||
channel = DiscordChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "discord-bot", "connection_repo": repo},
|
||||
)
|
||||
inbound = InboundMessage(
|
||||
channel_name="discord",
|
||||
chat_id="C123",
|
||||
user_id="987",
|
||||
text="hello",
|
||||
)
|
||||
|
||||
attached = await channel._attach_connection_identity(inbound, guild_id="G123")
|
||||
|
||||
assert attached.connection_id == connection["id"]
|
||||
assert attached.owner_user_id == "alice"
|
||||
assert attached.workspace_id is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_discord_connect_command_binds_gateway_identity(repo):
|
||||
state = "discord-bind-code"
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="discord",
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
channel = DiscordChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "discord-bot", "connection_repo": repo},
|
||||
)
|
||||
message = MagicMock()
|
||||
message.author.id = 987
|
||||
message.author.display_name = "Alice"
|
||||
message.guild.id = 123
|
||||
message.guild.name = "Deer Guild"
|
||||
message.channel.id = 456
|
||||
message.channel.send = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(message, state)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "discord"
|
||||
assert connections[0]["external_account_id"] == "987"
|
||||
assert connections[0]["external_account_name"] == "Alice"
|
||||
assert connections[0]["workspace_id"] == "123"
|
||||
assert connections[0]["workspace_name"] == "Deer Guild"
|
||||
assert connections[0]["metadata"]["channel_id"] == "456"
|
||||
message.channel.send.assert_awaited_once()
|
||||
@@ -192,7 +192,7 @@ def test_build_acp_section_uses_explicit_app_config_without_global_config(monkey
|
||||
|
||||
def test_get_memory_context_uses_explicit_app_config_without_global_config(monkeypatch):
|
||||
explicit_config = SimpleNamespace(
|
||||
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234),
|
||||
memory=SimpleNamespace(enabled=True, injection_enabled=True, max_injection_tokens=1234, token_counting="tiktoken"),
|
||||
)
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
@@ -204,9 +204,10 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
|
||||
captured["user_id"] = user_id
|
||||
return {"facts": []}
|
||||
|
||||
def fake_format_memory_for_injection(memory_data, *, max_tokens):
|
||||
def fake_format_memory_for_injection(memory_data, *, max_tokens, use_tiktoken=True):
|
||||
captured["memory_data"] = memory_data
|
||||
captured["max_tokens"] = max_tokens
|
||||
captured["use_tiktoken"] = use_tiktoken
|
||||
return "remember this"
|
||||
|
||||
monkeypatch.setattr("deerflow.config.memory_config.get_memory_config", fail_get_memory_config)
|
||||
@@ -223,6 +224,7 @@ def test_get_memory_context_uses_explicit_app_config_without_global_config(monke
|
||||
"user_id": "user-1",
|
||||
"memory_data": {"facts": []},
|
||||
"max_tokens": 1234,
|
||||
"use_tiktoken": True,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,7 @@ def test_format_memory_sorts_facts_by_confidence_desc() -> None:
|
||||
|
||||
def test_format_memory_respects_budget_when_adding_facts(monkeypatch) -> None:
|
||||
# Make token counting deterministic for this test by counting characters.
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base": len(text))
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt._count_tokens", lambda text, encoding_name="cl100k_base", *, use_tiktoken=True: len(text))
|
||||
|
||||
memory_data = {
|
||||
"user": {},
|
||||
|
||||
@@ -1,154 +0,0 @@
|
||||
"""Slack connection tests for user-owned channel bindings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.channels.message_bus import MessageBus, OutboundMessage
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'slack.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("slack-secret"),
|
||||
)
|
||||
|
||||
|
||||
def test_slack_connect_command_binds_socket_mode_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path)
|
||||
state = "slack-bind-code"
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="slack",
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
channel = SlackChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "xoxb-operator", "app_token": "xapp-operator", "connection_repo": repo},
|
||||
)
|
||||
channel._web_client = MagicMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
event={
|
||||
"user": "U123",
|
||||
"channel": "C123",
|
||||
"ts": "1710000000.000100",
|
||||
},
|
||||
team_id="T123",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "slack"
|
||||
assert connections[0]["external_account_id"] == "U123"
|
||||
assert connections[0]["workspace_id"] == "T123"
|
||||
assert connections[0]["metadata"]["channel_id"] == "C123"
|
||||
channel._web_client.chat_postMessage.assert_called_once()
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
|
||||
import anyio
|
||||
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
async def go():
|
||||
repo = AsyncMock()
|
||||
repo.get_credentials.return_value = {"access_token": "xoxb-connection-token"}
|
||||
web_client = MagicMock()
|
||||
web_client_factory = MagicMock(return_value=web_client)
|
||||
channel = SlackChannel(
|
||||
bus=MessageBus(),
|
||||
config={
|
||||
"connection_repo": repo,
|
||||
"web_client_factory": web_client_factory,
|
||||
},
|
||||
)
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C123",
|
||||
thread_id="thread-1",
|
||||
text="hello",
|
||||
connection_id="connection-1",
|
||||
)
|
||||
await channel.send(msg)
|
||||
|
||||
repo.get_credentials.assert_awaited_once_with("connection-1")
|
||||
web_client_factory.assert_called_once_with(token="xoxb-connection-token")
|
||||
web_client.chat_postMessage.assert_called_once()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
|
||||
import anyio
|
||||
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
class FakeWebClient:
|
||||
def __init__(self, token: str) -> None:
|
||||
self.token = token
|
||||
self.messages: list[dict] = []
|
||||
|
||||
def auth_test(self):
|
||||
return {"user_id": "B-http"}
|
||||
|
||||
def chat_postMessage(self, **kwargs):
|
||||
self.messages.append(kwargs)
|
||||
|
||||
slack_sdk = ModuleType("slack_sdk")
|
||||
slack_sdk.WebClient = FakeWebClient
|
||||
socket_mode = ModuleType("slack_sdk.socket_mode")
|
||||
socket_mode.SocketModeClient = object
|
||||
response = ModuleType("slack_sdk.socket_mode.response")
|
||||
response.SocketModeResponse = object
|
||||
monkeypatch.setitem(sys.modules, "slack_sdk", slack_sdk)
|
||||
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode", socket_mode)
|
||||
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode.response", response)
|
||||
|
||||
async def go():
|
||||
channel = SlackChannel(
|
||||
bus=MessageBus(),
|
||||
config={
|
||||
"bot_token": "xoxb-operator",
|
||||
"event_delivery": "http",
|
||||
"connection_repo": MagicMock(),
|
||||
},
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
assert channel._running is True
|
||||
assert channel._web_client is not None
|
||||
assert channel._web_client.token == "xoxb-operator"
|
||||
assert channel._bot_user_id == "B-http"
|
||||
|
||||
channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
|
||||
|
||||
assert channel._web_client.messages == [
|
||||
{
|
||||
"channel": "C123",
|
||||
"text": "Slack connected to DeerFlow.",
|
||||
"thread_ts": "1710000000.000100",
|
||||
}
|
||||
]
|
||||
await channel.stop()
|
||||
|
||||
anyio.run(go)
|
||||
@@ -0,0 +1,173 @@
|
||||
"""Cross-user isolation for the stateless ``POST /api/runs/stream`` and ``/wait`` endpoints.
|
||||
|
||||
These endpoints receive ``thread_id`` in the request body, so the
|
||||
``@require_permission(owner_check=True)`` decorator — which reads the
|
||||
``thread_id`` *path* parameter — cannot protect them. The owner check
|
||||
lives inside ``services.start_run()`` instead; this suite pins it at the
|
||||
HTTP layer so the gap cannot silently reopen.
|
||||
|
||||
Strategy
|
||||
--------
|
||||
``app.state.run_manager.create_or_reject`` raises ``ConflictError``, so a
|
||||
request that *passes* the owner check deterministically short-circuits
|
||||
with 409 before any agent code runs. The two outcomes:
|
||||
|
||||
- 404 + ``create_or_reject`` never awaited -> blocked by the owner check
|
||||
- 409 + ``create_or_reject`` awaited -> passed the owner check
|
||||
|
||||
The thread store is a real ``MemoryThreadMetaStore`` (not a mock) so the
|
||||
``check_access`` semantics under test — missing row allows, ``user_id``
|
||||
NULL allows, foreign owner denies — are exercised through real code.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import contextmanager
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.routers import runs
|
||||
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.runtime import ConflictError
|
||||
|
||||
USER_A = User(email="owner-a@example.com", password_hash="x", system_role="user", id=uuid4())
|
||||
USER_B = User(email="intruder-b@example.com", password_hash="x", system_role="user", id=uuid4())
|
||||
INTERNAL_USER = SimpleNamespace(id="default", system_role="internal")
|
||||
|
||||
THREAD_A = "thread-owned-by-a"
|
||||
THREAD_SHARED = "thread-shared-null-owner"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_app_config():
|
||||
"""Inject a minimal AppConfig so the allowed path (which builds a
|
||||
RunContext via ``get_config()``) never reads config.yaml from disk."""
|
||||
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
|
||||
yield
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _make_thread_store() -> MemoryThreadMetaStore:
|
||||
store = MemoryThreadMetaStore(InMemoryStore())
|
||||
|
||||
async def _seed():
|
||||
await store.create(THREAD_A, user_id=str(USER_A.id))
|
||||
await store.create(THREAD_SHARED, user_id=None)
|
||||
|
||||
asyncio.run(_seed())
|
||||
return store
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _client(user):
|
||||
"""Yield a ``TestClient`` authenticated as ``user`` plus the stubbed
|
||||
``create_or_reject`` mock, closing the client (and its anyio portal /
|
||||
background threads) on exit.
|
||||
|
||||
``create_or_reject`` raises ``ConflictError`` so a request that passes the
|
||||
owner check short-circuits to 409 before any agent code runs.
|
||||
"""
|
||||
app = make_authed_test_app(user_factory=lambda: user)
|
||||
app.include_router(runs.router)
|
||||
app.state.thread_store = _make_thread_store()
|
||||
app.state.stream_bridge = MagicMock()
|
||||
app.state.checkpointer = MagicMock()
|
||||
app.state.store = MagicMock()
|
||||
app.state.run_events_config = None
|
||||
app.state.run_event_store = MagicMock()
|
||||
run_manager = MagicMock()
|
||||
run_manager.create_or_reject = AsyncMock(side_effect=ConflictError("sentinel: owner check passed"))
|
||||
app.state.run_manager = run_manager
|
||||
with TestClient(app) as client:
|
||||
yield client, run_manager.create_or_reject
|
||||
|
||||
|
||||
def _body(thread_id: str | None = None) -> dict:
|
||||
if thread_id is None:
|
||||
return {}
|
||||
return {"config": {"configurable": {"thread_id": thread_id}}}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Denied: another user's thread
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stream_cross_user_returns_404():
|
||||
"""User B cannot start a run on user A's thread via /api/runs/stream."""
|
||||
with _client(USER_B) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||
assert response.status_code == 404
|
||||
assert response.json()["detail"] == f"Thread {THREAD_A} not found"
|
||||
create_or_reject.assert_not_awaited()
|
||||
|
||||
|
||||
def test_wait_cross_user_returns_404_without_channel_values():
|
||||
"""User B cannot read user A's checkpoint state via /api/runs/wait."""
|
||||
with _client(USER_B) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/wait", json=_body(THREAD_A))
|
||||
assert response.status_code == 404
|
||||
assert response.json() == {"detail": f"Thread {THREAD_A} not found"}
|
||||
create_or_reject.assert_not_awaited()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Allowed: owner, fresh/untracked/shared threads, internal role
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_stream_owner_passes_owner_check():
|
||||
"""User A reaches run creation on their own thread (409 sentinel)."""
|
||||
with _client(USER_A) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||
assert response.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
|
||||
def test_wait_owner_passes_owner_check():
|
||||
with _client(USER_A) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/wait", json=_body(THREAD_A))
|
||||
assert response.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
|
||||
def test_stream_without_thread_id_passes_owner_check():
|
||||
"""Stateless run with no thread_id auto-creates a thread — never blocked."""
|
||||
with _client(USER_B) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/stream", json=_body())
|
||||
assert response.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
|
||||
def test_stream_untracked_thread_passes_owner_check():
|
||||
"""A thread_id with no thread_meta row (untracked legacy) stays accessible."""
|
||||
with _client(USER_B) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/stream", json=_body("never-created-thread"))
|
||||
assert response.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
|
||||
def test_stream_shared_thread_passes_owner_check():
|
||||
"""A thread_meta row with user_id NULL (shared / pre-auth data) stays accessible."""
|
||||
with _client(USER_B) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/stream", json=_body(THREAD_SHARED))
|
||||
assert response.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
|
||||
def test_stream_internal_role_bypasses_owner_check():
|
||||
"""IM channels run with the internal system role on behalf of platform
|
||||
users whose threads they do not own — the owner check must not break them."""
|
||||
with _client(INTERNAL_USER) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||
assert response.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Tests for Telegram deep-link channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.channels.message_bus import MessageBus
|
||||
from app.channels.telegram import TelegramChannel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path: Path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'telegram.db'}", sqlite_dir=str(tmp_path))
|
||||
try:
|
||||
yield ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("telegram-secret"),
|
||||
)
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
def _telegram_update(*, text: str = "/start", user_id: int = 42, chat_id: int = 100, chat_type: str = "private"):
|
||||
update = MagicMock()
|
||||
update.effective_user.id = user_id
|
||||
update.effective_user.username = "alice"
|
||||
update.effective_user.full_name = "Alice Example"
|
||||
update.effective_chat.id = chat_id
|
||||
update.effective_chat.type = chat_type
|
||||
update.message.text = text
|
||||
update.message.message_id = 55
|
||||
update.message.reply_to_message = None
|
||||
update.message.reply_text = AsyncMock()
|
||||
return update
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_start_with_deep_link_state_binds_telegram_chat(repo):
|
||||
state = "telegram-bind-state"
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="telegram",
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
channel = TelegramChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "test-token", "connection_repo": repo},
|
||||
)
|
||||
update = _telegram_update(text=f"/start {state}")
|
||||
context = MagicMock()
|
||||
context.args = [state]
|
||||
|
||||
await channel._cmd_start(update, context)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "telegram"
|
||||
assert connections[0]["external_account_id"] == "42"
|
||||
assert connections[0]["external_account_name"] == "Alice Example"
|
||||
assert connections[0]["workspace_id"] == "100"
|
||||
assert connections[0]["metadata"]["chat_type"] == "private"
|
||||
update.message.reply_text.assert_awaited_once()
|
||||
assert "connected" in update.message.reply_text.await_args.args[0].lower()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_bound_telegram_message_publishes_connection_identity(repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice Example",
|
||||
workspace_id="100",
|
||||
metadata={"chat_type": "private"},
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(
|
||||
bus=bus,
|
||||
config={"bot_token": "test-token", "connection_repo": repo},
|
||||
)
|
||||
channel._main_loop = __import__("asyncio").get_event_loop()
|
||||
channel._send_running_reply = AsyncMock()
|
||||
|
||||
await channel._on_text(_telegram_update(text="hello"), None)
|
||||
inbound = await bus.get_inbound()
|
||||
|
||||
assert inbound.connection_id == connection["id"]
|
||||
assert inbound.owner_user_id == "deerflow-user-1"
|
||||
assert inbound.workspace_id == "100"
|
||||
assert inbound.user_id == "42"
|
||||
assert inbound.chat_id == "100"
|
||||
assert inbound.text == "hello"
|
||||
@@ -5,18 +5,22 @@ Verifies:
|
||||
- ``_count_tokens`` falls back to character estimation when tiktoken is
|
||||
unavailable or the encoding fails to load.
|
||||
- ``warm_tiktoken_cache`` populates the cache on success.
|
||||
- An in-flight tiktoken load prevents duplicate blocking downloads.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
from deerflow.agents.memory.prompt import (
|
||||
_count_tokens,
|
||||
_get_tiktoken_encoding,
|
||||
_tiktoken_encoding_cache,
|
||||
format_memory_for_injection,
|
||||
warm_tiktoken_cache,
|
||||
)
|
||||
from deerflow.config.memory_config import MemoryConfig
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _get_tiktoken_encoding
|
||||
@@ -62,14 +66,103 @@ class TestGetTiktokenEncoding:
|
||||
assert enc is fake_enc
|
||||
tiktoken.get_encoding.assert_not_called()
|
||||
|
||||
def test_returns_none_and_warns_on_get_encoding_failure(self, monkeypatch):
|
||||
def test_returns_none_and_caches_failure_sentinel(self, monkeypatch):
|
||||
"""A failed load is cached (with a timestamp) so it is not re-attempted (no repeated network download)."""
|
||||
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
||||
import tiktoken
|
||||
|
||||
monkeypatch.setattr(tiktoken, "get_encoding", mock.Mock(side_effect=OSError("download failed")))
|
||||
get_encoding = mock.Mock(side_effect=OSError("download failed"))
|
||||
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||
|
||||
result = _get_tiktoken_encoding("bogus_encoding")
|
||||
assert result is None
|
||||
assert "bogus_encoding" not in _tiktoken_encoding_cache
|
||||
# The failure is remembered as a (None, timestamp) tuple.
|
||||
assert "bogus_encoding" in _tiktoken_encoding_cache
|
||||
cached = _tiktoken_encoding_cache["bogus_encoding"]
|
||||
assert isinstance(cached, tuple)
|
||||
assert cached[0] is None
|
||||
|
||||
# A second call must NOT re-attempt get_encoding (avoids re-blocking on
|
||||
# the network download in restricted environments — see #3429).
|
||||
result2 = _get_tiktoken_encoding("bogus_encoding")
|
||||
assert result2 is None
|
||||
assert get_encoding.call_count == 1
|
||||
|
||||
# Cleanup module-level cache to avoid cross-test leakage.
|
||||
_tiktoken_encoding_cache.pop("bogus_encoding", None)
|
||||
|
||||
def test_failure_self_heals_after_cooldown(self, monkeypatch):
|
||||
"""After the retry cooldown expires, a transient failure is re-attempted and can recover."""
|
||||
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
||||
import tiktoken
|
||||
|
||||
fake_enc = mock.Mock()
|
||||
# First call fails, second call (after cooldown) succeeds.
|
||||
get_encoding = mock.Mock(side_effect=[OSError("transient outage"), fake_enc])
|
||||
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||
|
||||
# Initial failure is cached.
|
||||
assert _get_tiktoken_encoding("flaky_encoding") is None
|
||||
assert get_encoding.call_count == 1
|
||||
|
||||
# Within the cooldown window: no retry, immediate fallback.
|
||||
assert _get_tiktoken_encoding("flaky_encoding") is None
|
||||
assert get_encoding.call_count == 1
|
||||
|
||||
# Simulate the cooldown having elapsed by ageing the cached timestamp.
|
||||
from deerflow.agents.memory import prompt as prompt_module
|
||||
|
||||
_, _failed_at = _tiktoken_encoding_cache["flaky_encoding"]
|
||||
_tiktoken_encoding_cache["flaky_encoding"] = (
|
||||
None,
|
||||
_failed_at - prompt_module._TIKTOKEN_RETRY_COOLDOWN_S - 1,
|
||||
)
|
||||
|
||||
# Now the load is retried and recovers to accurate counting.
|
||||
assert _get_tiktoken_encoding("flaky_encoding") is fake_enc
|
||||
assert get_encoding.call_count == 2
|
||||
|
||||
_tiktoken_encoding_cache.pop("flaky_encoding", None)
|
||||
|
||||
def test_in_flight_load_returns_none_without_duplicate_get_encoding(self, monkeypatch):
|
||||
"""Concurrent callers must not start duplicate blocking BPE downloads."""
|
||||
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
||||
import tiktoken
|
||||
|
||||
started = threading.Event()
|
||||
release = threading.Event()
|
||||
fake_enc = mock.Mock()
|
||||
|
||||
def slow_get_encoding(_name):
|
||||
started.set()
|
||||
assert release.wait(timeout=2), "test timed out waiting to release slow get_encoding"
|
||||
return fake_enc
|
||||
|
||||
get_encoding = mock.Mock(side_effect=slow_get_encoding)
|
||||
monkeypatch.setattr(tiktoken, "get_encoding", get_encoding)
|
||||
|
||||
result: dict[str, object | None] = {}
|
||||
|
||||
def load_encoding():
|
||||
result["encoding"] = _get_tiktoken_encoding("slow_encoding")
|
||||
|
||||
thread = threading.Thread(target=load_encoding)
|
||||
thread.start()
|
||||
try:
|
||||
assert started.wait(timeout=1), "slow get_encoding did not start"
|
||||
|
||||
# While the first call is still blocked, a second call should see
|
||||
# the in-flight sentinel and fall back immediately instead of
|
||||
# starting another potentially long network download.
|
||||
assert _get_tiktoken_encoding("slow_encoding") is None
|
||||
assert get_encoding.call_count == 1
|
||||
finally:
|
||||
release.set()
|
||||
thread.join(timeout=2)
|
||||
_tiktoken_encoding_cache.pop("slow_encoding", None)
|
||||
|
||||
assert result["encoding"] is fake_enc
|
||||
assert get_encoding.call_count == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -115,6 +208,45 @@ class TestCountTokens:
|
||||
result = _count_tokens(text, encoding_name="test_enc")
|
||||
assert result == len(text) // 4
|
||||
|
||||
def test_use_tiktoken_false_returns_char_estimate_without_touching_tiktoken(self, monkeypatch):
|
||||
"""use_tiktoken=False must never call tiktoken (guarantees no BPE download)."""
|
||||
# Spy on both the encoding loader and tiktoken.get_encoding directly.
|
||||
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
||||
loader_spy = mock.Mock(side_effect=AssertionError("_get_tiktoken_encoding must not be called"))
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt._get_tiktoken_encoding", loader_spy)
|
||||
|
||||
text = "Hello, world! This is a network-free count."
|
||||
result = _count_tokens(text, use_tiktoken=False)
|
||||
assert result == len(text) // 4
|
||||
get_encoding_spy.assert_not_called()
|
||||
loader_spy.assert_not_called()
|
||||
|
||||
def test_cjk_estimate_is_denser_than_plain_quarter(self, monkeypatch):
|
||||
"""CJK text should estimate more tokens than the plain len // 4 heuristic.
|
||||
|
||||
CJK characters are ~2 chars/token, so the char-based estimate must not
|
||||
under-fill the budget the way ``len(text) // 4`` would.
|
||||
"""
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||
# "User prefers concise answers" rendered in CJK (Chinese) characters.
|
||||
text = "\u7528\u6237\u504f\u597d\u7b80\u6d01\u7684\u4e2d\u6587\u56de\u7b54\u5e76\u5173\u6ce8\u91d1\u878d\u9886\u57df"
|
||||
result = _count_tokens(text)
|
||||
# Each CJK char counts as ~1/2 token (vs 1/4 for the plain heuristic).
|
||||
assert result == len(text) // 2
|
||||
assert result > len(text) // 4
|
||||
|
||||
def test_cjk_estimate_combines_cjk_and_non_cjk_characters(self, monkeypatch):
|
||||
"""Mixed-language text should apply the CJK density only to CJK chars."""
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||
# ASCII words mixed with CJK (Chinese) characters: "User" + "likes" + "Python and data analysis".
|
||||
text = "User\u559c\u6b22Python\u548c\u6570\u636e\u5206\u6790"
|
||||
cjk = sum(1 for ch in text if "\u4e00" <= ch <= "\u9fff")
|
||||
|
||||
result = _count_tokens(text)
|
||||
|
||||
assert result == (len(text) - cjk) // 4 + cjk // 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# warm_tiktoken_cache
|
||||
@@ -146,3 +278,69 @@ class TestWarmTiktokenCache:
|
||||
def test_returns_false_when_tiktoken_unavailable(self, monkeypatch):
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt.TIKTOKEN_AVAILABLE", False)
|
||||
assert warm_tiktoken_cache() is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# format_memory_for_injection token_counting strategy
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFormatMemoryForInjectionTokenCounting:
|
||||
"""Verify the use_tiktoken flag is honoured end-to-end."""
|
||||
|
||||
@staticmethod
|
||||
def _sample_memory() -> dict:
|
||||
return {
|
||||
"facts": [
|
||||
{"content": "User prefers concise answers.", "category": "preference", "confidence": 0.9},
|
||||
{"content": "User works in the finance domain.", "category": "context", "confidence": 0.8},
|
||||
],
|
||||
}
|
||||
|
||||
def test_use_tiktoken_false_never_touches_tiktoken(self, monkeypatch):
|
||||
"""With use_tiktoken=False, formatting must not call tiktoken at all."""
|
||||
get_encoding_spy = mock.Mock(side_effect=AssertionError("get_encoding must not be called"))
|
||||
monkeypatch.setattr("deerflow.agents.memory.prompt.tiktoken.get_encoding", get_encoding_spy)
|
||||
|
||||
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=False)
|
||||
assert "User prefers concise answers." in result
|
||||
get_encoding_spy.assert_not_called()
|
||||
|
||||
def test_use_tiktoken_true_uses_encoding(self, monkeypatch):
|
||||
"""With use_tiktoken=True (default), the cached encoding is used for counting."""
|
||||
fake_enc = mock.Mock()
|
||||
fake_enc.encode.side_effect = lambda text: list(range(len(text)))
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.memory.prompt._get_tiktoken_encoding",
|
||||
mock.Mock(return_value=fake_enc),
|
||||
)
|
||||
|
||||
result = format_memory_for_injection(self._sample_memory(), max_tokens=2000, use_tiktoken=True)
|
||||
assert "User prefers concise answers." in result
|
||||
assert fake_enc.encode.called
|
||||
|
||||
def test_empty_memory_returns_empty(self):
|
||||
assert format_memory_for_injection({}, max_tokens=2000, use_tiktoken=False) == ""
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MemoryConfig.token_counting
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMemoryConfigTokenCounting:
|
||||
"""Verify the new config field defaults and validation."""
|
||||
|
||||
def test_default_is_tiktoken(self):
|
||||
"""Default must remain tiktoken so existing deployments are unaffected."""
|
||||
assert MemoryConfig().token_counting == "tiktoken"
|
||||
|
||||
def test_accepts_char(self):
|
||||
assert MemoryConfig(token_counting="char").token_counting == "char"
|
||||
|
||||
def test_rejects_invalid_value(self):
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
MemoryConfig(token_counting="invalid")
|
||||
|
||||
Generated
-2
@@ -820,7 +820,6 @@ dependencies = [
|
||||
{ name = "agent-sandbox" },
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "alembic" },
|
||||
{ name = "cryptography" },
|
||||
{ name = "ddgs" },
|
||||
{ name = "dotenv" },
|
||||
{ name = "duckdb" },
|
||||
@@ -872,7 +871,6 @@ requires-dist = [
|
||||
{ name = "aiosqlite", specifier = ">=0.19" },
|
||||
{ name = "alembic", specifier = ">=1.13" },
|
||||
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
|
||||
{ name = "cryptography", specifier = ">=43.0.0" },
|
||||
{ name = "ddgs", specifier = ">=9.10.0" },
|
||||
{ name = "dotenv", specifier = ">=0.9.9" },
|
||||
{ name = "duckdb", specifier = ">=1.4.4" },
|
||||
|
||||
+9
-27
@@ -1024,6 +1024,15 @@ memory:
|
||||
fact_confidence_threshold: 0.7 # Minimum confidence for storing facts
|
||||
injection_enabled: true # Whether to inject memory into system prompt
|
||||
max_injection_tokens: 2000 # Maximum tokens for memory injection
|
||||
# Token counting strategy for memory-injection budgeting:
|
||||
# tiktoken (default) - accurate, but the encoding's BPE data may be
|
||||
# downloaded from a public network endpoint on first use. In
|
||||
# network-restricted environments this download can block for a long
|
||||
# time (see issues #3402 / #3429). Pre-cache the encoding or set this
|
||||
# to "char" to avoid it.
|
||||
# char - network-free CJK-aware character-based estimate; never touches
|
||||
# tiktoken. Slightly less precise budgeting, zero network I/O.
|
||||
token_counting: tiktoken
|
||||
|
||||
# ============================================================================
|
||||
# Custom Agent Management API
|
||||
@@ -1131,33 +1140,6 @@ run_events:
|
||||
max_trace_content: 10240
|
||||
track_token_usage: true
|
||||
|
||||
# ============================================================================
|
||||
# User-Owned IM Channel Connections
|
||||
# ============================================================================
|
||||
# Lets logged-in users connect their own Telegram, Slack, and Discord accounts
|
||||
# from the DeerFlow frontend while reusing the existing `channels` runtime
|
||||
# configuration below.
|
||||
#
|
||||
# Security notes:
|
||||
# - No public IP, OAuth callback URL, or provider webhook is required.
|
||||
# - Provider bot/app credentials stay under `channels.*`.
|
||||
# - `channel_connections` stores per-user bindings and one-time connect codes.
|
||||
# - Telegram uses a deep link when `bot_username` is configured.
|
||||
# - Slack and Discord use `/connect <code>` through the already-running bot.
|
||||
#
|
||||
# channel_connections:
|
||||
# enabled: false
|
||||
#
|
||||
# telegram:
|
||||
# enabled: false
|
||||
# bot_username: $TELEGRAM_BOT_USERNAME
|
||||
#
|
||||
# slack:
|
||||
# enabled: false
|
||||
#
|
||||
# discord:
|
||||
# enabled: false
|
||||
|
||||
# ============================================================================
|
||||
# IM Channels Configuration
|
||||
# ============================================================================
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
"use client";
|
||||
|
||||
import Link from "next/link";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { Input } from "@/components/ui/input";
|
||||
import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import {
|
||||
@@ -11,24 +12,58 @@ import {
|
||||
WorkspaceHeader,
|
||||
} from "@/components/workspace/workspace-container";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { useThreads } from "@/core/threads/hooks";
|
||||
import { useInfiniteThreads } from "@/core/threads/hooks";
|
||||
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
||||
import { formatTimeAgo } from "@/core/utils/datetime";
|
||||
|
||||
export default function ChatsPage() {
|
||||
const { t } = useI18n();
|
||||
const { data: threads } = useThreads();
|
||||
const {
|
||||
data: infiniteThreads,
|
||||
fetchNextPage,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
} = useInfiniteThreads();
|
||||
const threads = useMemo(
|
||||
() => infiniteThreads?.pages.flat() ?? [],
|
||||
[infiniteThreads],
|
||||
);
|
||||
const [search, setSearch] = useState("");
|
||||
const isSearching = search.trim().length > 0;
|
||||
|
||||
useEffect(() => {
|
||||
document.title = `${t.pages.chats} - ${t.pages.appName}`;
|
||||
}, [t.pages.chats, t.pages.appName]);
|
||||
|
||||
const filteredThreads = useMemo(() => {
|
||||
return threads?.filter((thread) => {
|
||||
return threads.filter((thread) => {
|
||||
return titleOfThread(thread).toLowerCase().includes(search.toLowerCase());
|
||||
});
|
||||
}, [threads, search]);
|
||||
|
||||
// Sentinel-based auto load-more for the unfiltered list (issue #3482).
|
||||
// In search mode we deliberately do NOT auto-paginate, otherwise an empty
|
||||
// filtered view would keep the sentinel in the viewport and drain the
|
||||
// entire backend list one page at a time. Searching falls back to an
|
||||
// explicit button so users can still reach older conversations on demand.
|
||||
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||
useEffect(() => {
|
||||
const element = sentinelRef.current;
|
||||
if (!element || !hasNextPage || isSearching) {
|
||||
return;
|
||||
}
|
||||
const observer = new IntersectionObserver(
|
||||
([entry]) => {
|
||||
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
|
||||
void fetchNextPage();
|
||||
}
|
||||
},
|
||||
{ rootMargin: "200px 0px 200px 0px" },
|
||||
);
|
||||
observer.observe(element);
|
||||
return () => observer.disconnect();
|
||||
}, [fetchNextPage, hasNextPage, isFetchingNextPage, isSearching]);
|
||||
|
||||
return (
|
||||
<WorkspaceContainer>
|
||||
<WorkspaceHeader></WorkspaceHeader>
|
||||
@@ -61,6 +96,28 @@ export default function ChatsPage() {
|
||||
</div>
|
||||
</Link>
|
||||
))}
|
||||
{hasNextPage && !isSearching && (
|
||||
<div
|
||||
ref={sentinelRef}
|
||||
aria-hidden="true"
|
||||
className="h-px w-full"
|
||||
data-testid="chats-page-sentinel"
|
||||
/>
|
||||
)}
|
||||
{hasNextPage && isSearching && (
|
||||
<div className="flex justify-center p-4">
|
||||
<Button
|
||||
variant="outline"
|
||||
onClick={() => void fetchNextPage()}
|
||||
disabled={isFetchingNextPage}
|
||||
data-testid="chats-page-load-more"
|
||||
>
|
||||
{isFetchingNextPage
|
||||
? t.chats.loadingMore
|
||||
: t.chats.loadMoreToSearch}
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
</main>
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { MessageCircleIcon } from "lucide-react";
|
||||
import type { SVGProps } from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
type ChannelProviderIconProps = SVGProps<SVGSVGElement> & {
|
||||
provider: string;
|
||||
};
|
||||
|
||||
export function ChannelProviderIcon({
|
||||
provider,
|
||||
className,
|
||||
...props
|
||||
}: ChannelProviderIconProps) {
|
||||
const normalizedProvider = provider.toLowerCase();
|
||||
|
||||
if (normalizedProvider === "telegram") {
|
||||
return (
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
aria-hidden="true"
|
||||
className={cn("size-5", className)}
|
||||
{...props}
|
||||
>
|
||||
<circle cx="12" cy="12" r="11" fill="#2AABEE" />
|
||||
<path
|
||||
fill="#FFFFFF"
|
||||
d="M17.4 7.2 15.7 16c-.1.7-.5.9-1 .6l-2.8-2.1-1.4 1.3c-.1.2-.3.3-.6.3l.2-2.9 5.3-4.8c.2-.2 0-.3-.3-.1l-6.6 4.1-2.8-.9c-.6-.2-.6-.6.1-.8l10.9-4.2c.5-.2.9.1.7.7Z"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
if (normalizedProvider === "slack") {
|
||||
return (
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
aria-hidden="true"
|
||||
className={cn("size-5", className)}
|
||||
{...props}
|
||||
>
|
||||
<rect x="10.1" y="2" width="3.8" height="8.5" rx="1.9" fill="#36C5F0" />
|
||||
<rect
|
||||
x="10.1"
|
||||
y="13.5"
|
||||
width="3.8"
|
||||
height="8.5"
|
||||
rx="1.9"
|
||||
fill="#2EB67D"
|
||||
/>
|
||||
<rect x="2" y="10.1" width="8.5" height="3.8" rx="1.9" fill="#ECB22E" />
|
||||
<rect
|
||||
x="13.5"
|
||||
y="10.1"
|
||||
width="8.5"
|
||||
height="3.8"
|
||||
rx="1.9"
|
||||
fill="#E01E5A"
|
||||
/>
|
||||
<path
|
||||
d="M8.2 2a1.9 1.9 0 0 1 1.9 1.9v1.9H8.2a1.9 1.9 0 1 1 0-3.8Z"
|
||||
fill="#36C5F0"
|
||||
/>
|
||||
<path
|
||||
d="M15.8 22a1.9 1.9 0 0 1-1.9-1.9v-1.9h1.9a1.9 1.9 0 1 1 0 3.8Z"
|
||||
fill="#2EB67D"
|
||||
/>
|
||||
<path
|
||||
d="M2 15.8a1.9 1.9 0 0 1 1.9-1.9h1.9v1.9a1.9 1.9 0 1 1-3.8 0Z"
|
||||
fill="#ECB22E"
|
||||
/>
|
||||
<path
|
||||
d="M22 8.2a1.9 1.9 0 0 1-1.9 1.9h-1.9V8.2a1.9 1.9 0 1 1 3.8 0Z"
|
||||
fill="#E01E5A"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
if (normalizedProvider === "discord") {
|
||||
return (
|
||||
<svg
|
||||
viewBox="0 0 24 24"
|
||||
aria-hidden="true"
|
||||
className={cn("size-5", className)}
|
||||
{...props}
|
||||
>
|
||||
<circle cx="12" cy="12" r="11" fill="#5865F2" />
|
||||
<path
|
||||
fill="#FFFFFF"
|
||||
d="M8.1 8.4c1.4-.6 2.7-.7 3.9-.7s2.5.1 3.9.7c1 1.5 1.5 3.1 1.4 4.8-.9.7-1.8 1.1-2.8 1.3l-.7-1.1c.4-.1.7-.3 1.1-.5-.3.1-.6.3-.9.4-.7.3-1.4.4-2 .4s-1.3-.1-2-.4c-.3-.1-.6-.2-.9-.4.3.2.7.4 1.1.5l-.7 1.1c-1-.2-1.9-.6-2.8-1.3-.1-1.7.4-3.3 1.4-4.8Zm2.1 3.9c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Zm3.6 0c.5 0 .9-.5.9-1.1s-.4-1.1-.9-1.1-.9.5-.9 1.1.4 1.1.9 1.1Z"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<MessageCircleIcon aria-hidden="true" className={cn("size-5", className)} />
|
||||
);
|
||||
}
|
||||
@@ -1,149 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { CheckIcon, LoaderCircleIcon } from "lucide-react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
SidebarGroup,
|
||||
SidebarGroupLabel,
|
||||
SidebarMenu,
|
||||
SidebarMenuItem,
|
||||
useSidebar,
|
||||
} from "@/components/ui/sidebar";
|
||||
import { Skeleton } from "@/components/ui/skeleton";
|
||||
import {
|
||||
useChannelProviders,
|
||||
useConnectChannelProvider,
|
||||
} from "@/core/channels/hooks";
|
||||
import {
|
||||
closeConnectWindow,
|
||||
openConnectUrl,
|
||||
prepareConnectWindow,
|
||||
} from "@/core/channels/open-connect-url";
|
||||
import type { ChannelProvider } from "@/core/channels/types";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { ChannelProviderIcon } from "./channel-provider-icon";
|
||||
|
||||
function providerCanConnect(provider: ChannelProvider): boolean {
|
||||
return (
|
||||
(provider.connectable ?? (provider.enabled && provider.configured)) &&
|
||||
provider.connection_status !== "connected"
|
||||
);
|
||||
}
|
||||
|
||||
function getProviderDisabledReason(
|
||||
provider: ChannelProvider,
|
||||
t: ReturnType<typeof useI18n>["t"],
|
||||
): string | undefined {
|
||||
if (!provider.enabled) {
|
||||
return t.channels.disabled;
|
||||
}
|
||||
if (!provider.configured) {
|
||||
return t.channels.unconfigured;
|
||||
}
|
||||
return provider.unavailable_reason ?? undefined;
|
||||
}
|
||||
|
||||
export function WorkspaceChannelsList() {
|
||||
const { open: isSidebarOpen } = useSidebar();
|
||||
const { t } = useI18n();
|
||||
const { enabled, providers, isLoading, error } = useChannelProviders();
|
||||
const connectMutation = useConnectChannelProvider();
|
||||
|
||||
if (!isSidebarOpen) {
|
||||
return null;
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<SidebarGroup className="pt-0">
|
||||
<SidebarGroupLabel>{t.sidebar.channels}</SidebarGroupLabel>
|
||||
<div className="space-y-2 px-2 py-1">
|
||||
<Skeleton className="h-8 w-full" />
|
||||
<Skeleton className="h-8 w-full" />
|
||||
<Skeleton className="h-8 w-full" />
|
||||
</div>
|
||||
</SidebarGroup>
|
||||
);
|
||||
}
|
||||
|
||||
if (error || !enabled || providers.length === 0) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<SidebarGroup className="pt-0">
|
||||
<SidebarGroupLabel>{t.sidebar.channels}</SidebarGroupLabel>
|
||||
<SidebarMenu>
|
||||
{providers.map((provider) => {
|
||||
const isConnected = provider.connection_status === "connected";
|
||||
const isPending =
|
||||
connectMutation.isPending &&
|
||||
connectMutation.variables === provider.provider;
|
||||
const canConnect = providerCanConnect(provider);
|
||||
|
||||
return (
|
||||
<SidebarMenuItem key={provider.provider}>
|
||||
<div className="hover:bg-sidebar-accent flex h-10 items-center gap-2 rounded-md px-2 transition-colors">
|
||||
<ChannelProviderIcon
|
||||
provider={provider.provider}
|
||||
className="size-5 shrink-0"
|
||||
/>
|
||||
<span className="min-w-0 flex-1 truncate text-sm font-medium">
|
||||
{provider.display_name}
|
||||
</span>
|
||||
<Button
|
||||
type="button"
|
||||
size="sm"
|
||||
variant={isConnected ? "outline" : "secondary"}
|
||||
className={cn(
|
||||
"h-8 w-24 px-2 text-xs",
|
||||
isConnected && "gap-1",
|
||||
)}
|
||||
disabled={!canConnect || isPending}
|
||||
title={getProviderDisabledReason(provider, t)}
|
||||
onClick={() => {
|
||||
const connectWindow =
|
||||
provider.auth_mode === "deep_link"
|
||||
? prepareConnectWindow()
|
||||
: null;
|
||||
void connectMutation
|
||||
.mutateAsync(provider.provider)
|
||||
.then((result) => {
|
||||
if (result.url) {
|
||||
openConnectUrl(result.url, connectWindow);
|
||||
return;
|
||||
}
|
||||
closeConnectWindow(connectWindow);
|
||||
toast.success(result.instruction);
|
||||
})
|
||||
.catch((error) => {
|
||||
closeConnectWindow(connectWindow);
|
||||
toast.error(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: t.channels.unavailable,
|
||||
);
|
||||
});
|
||||
}}
|
||||
>
|
||||
{isPending ? (
|
||||
<LoaderCircleIcon className="size-3.5 animate-spin" />
|
||||
) : isConnected ? (
|
||||
<CheckIcon className="size-3.5" />
|
||||
) : null}
|
||||
<span>
|
||||
{isConnected ? t.channels.connected : t.channels.connect}
|
||||
</span>
|
||||
</Button>
|
||||
</div>
|
||||
</SidebarMenuItem>
|
||||
);
|
||||
})}
|
||||
</SidebarMenu>
|
||||
</SidebarGroup>
|
||||
);
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import {
|
||||
} from "lucide-react";
|
||||
import Link from "next/link";
|
||||
import { useParams, usePathname, useRouter } from "next/navigation";
|
||||
import { useCallback, useState } from "react";
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
import { Button } from "@/components/ui/button";
|
||||
@@ -51,8 +51,8 @@ import {
|
||||
} from "@/core/threads/export";
|
||||
import {
|
||||
useDeleteThread,
|
||||
useInfiniteThreads,
|
||||
useRenameThread,
|
||||
useThreads,
|
||||
} from "@/core/threads/hooks";
|
||||
import type { AgentThread, AgentThreadState } from "@/core/threads/types";
|
||||
import { pathOfThread, titleOfThread } from "@/core/threads/utils";
|
||||
@@ -68,7 +68,35 @@ export function RecentChatList() {
|
||||
thread_id: string;
|
||||
agent_name?: string;
|
||||
}>();
|
||||
const { data: threads = [] } = useThreads();
|
||||
const {
|
||||
data: infiniteThreads,
|
||||
fetchNextPage,
|
||||
hasNextPage,
|
||||
isFetchingNextPage,
|
||||
} = useInfiniteThreads();
|
||||
const threads = useMemo(
|
||||
() => infiniteThreads?.pages.flat() ?? [],
|
||||
[infiniteThreads],
|
||||
);
|
||||
|
||||
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||
useEffect(() => {
|
||||
const element = sentinelRef.current;
|
||||
if (!element || !hasNextPage) {
|
||||
return;
|
||||
}
|
||||
const observer = new IntersectionObserver(
|
||||
([entry]) => {
|
||||
if (entry?.isIntersecting && hasNextPage && !isFetchingNextPage) {
|
||||
void fetchNextPage();
|
||||
}
|
||||
},
|
||||
{ rootMargin: "120px 0px 120px 0px" },
|
||||
);
|
||||
observer.observe(element);
|
||||
return () => observer.disconnect();
|
||||
}, [fetchNextPage, hasNextPage, isFetchingNextPage]);
|
||||
|
||||
const { mutate: deleteThread } = useDeleteThread();
|
||||
const { mutate: renameThread } = useRenameThread();
|
||||
|
||||
@@ -267,6 +295,28 @@ export function RecentChatList() {
|
||||
</SidebarMenuItem>
|
||||
);
|
||||
})}
|
||||
{hasNextPage && (
|
||||
<>
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
className="mx-2 my-1 w-[calc(100%-1rem)] justify-center text-xs"
|
||||
onClick={() => void fetchNextPage()}
|
||||
disabled={isFetchingNextPage}
|
||||
data-testid="recent-chat-list-load-more"
|
||||
>
|
||||
{isFetchingNextPage
|
||||
? t.chats.loadingMore
|
||||
: t.chats.loadOlderChats}
|
||||
</Button>
|
||||
<div
|
||||
ref={sentinelRef}
|
||||
aria-hidden="true"
|
||||
className="h-px w-full"
|
||||
data-testid="recent-chat-list-sentinel"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</SidebarMenu>
|
||||
</SidebarGroupContent>
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
AlertCircleIcon,
|
||||
CheckCircle2Icon,
|
||||
LoaderCircleIcon,
|
||||
PlugIcon,
|
||||
UnplugIcon,
|
||||
} from "lucide-react";
|
||||
import { toast } from "sonner";
|
||||
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import {
|
||||
Item,
|
||||
ItemActions,
|
||||
ItemContent,
|
||||
ItemDescription,
|
||||
ItemMedia,
|
||||
ItemTitle,
|
||||
} from "@/components/ui/item";
|
||||
import {
|
||||
useChannelConnections,
|
||||
useChannelProviders,
|
||||
useConnectChannelProvider,
|
||||
useDisconnectChannelConnection,
|
||||
} from "@/core/channels/hooks";
|
||||
import {
|
||||
closeConnectWindow,
|
||||
openConnectUrl,
|
||||
prepareConnectWindow,
|
||||
} from "@/core/channels/open-connect-url";
|
||||
import type { ChannelConnection, ChannelProvider } from "@/core/channels/types";
|
||||
import { useI18n } from "@/core/i18n/hooks";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
import { ChannelProviderIcon } from "../channels/channel-provider-icon";
|
||||
|
||||
import { SettingsSection } from "./settings-section";
|
||||
|
||||
function getProviderDescription(
|
||||
provider: ChannelProvider,
|
||||
descriptions: Record<string, string>,
|
||||
): string {
|
||||
return descriptions[provider.provider] ?? provider.display_name;
|
||||
}
|
||||
|
||||
function getConnectionLabel(connection: ChannelConnection): string | null {
|
||||
const account = connection.external_account_name;
|
||||
const workspace = connection.workspace_name;
|
||||
if (account && workspace) {
|
||||
return `${account} · ${workspace}`;
|
||||
}
|
||||
return account ?? workspace ?? connection.external_account_id ?? null;
|
||||
}
|
||||
|
||||
function getStatusLabel(
|
||||
provider: ChannelProvider,
|
||||
connection: ChannelConnection | undefined,
|
||||
t: ReturnType<typeof useI18n>["t"],
|
||||
): string {
|
||||
if (!provider.enabled) {
|
||||
return t.channels.disabled;
|
||||
}
|
||||
if (!provider.configured) {
|
||||
return t.channels.unconfigured;
|
||||
}
|
||||
if (provider.unavailable_reason) {
|
||||
return t.channels.unavailableShort;
|
||||
}
|
||||
const status = connection?.status ?? provider.connection_status;
|
||||
if (status === "connected") {
|
||||
return t.channels.connected;
|
||||
}
|
||||
if (status === "pending") {
|
||||
return t.channels.pending;
|
||||
}
|
||||
if (status === "revoked") {
|
||||
return t.channels.revoked;
|
||||
}
|
||||
return t.channels.notConnected;
|
||||
}
|
||||
|
||||
function getProviderDisabledReason(
|
||||
provider: ChannelProvider,
|
||||
t: ReturnType<typeof useI18n>["t"],
|
||||
): string | undefined {
|
||||
if (!provider.enabled) {
|
||||
return t.channels.disabled;
|
||||
}
|
||||
if (!provider.configured) {
|
||||
return t.channels.unconfigured;
|
||||
}
|
||||
return provider.unavailable_reason ?? undefined;
|
||||
}
|
||||
|
||||
function ChannelProviderItem({
|
||||
provider,
|
||||
connection,
|
||||
}: {
|
||||
provider: ChannelProvider;
|
||||
connection?: ChannelConnection;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
const connectMutation = useConnectChannelProvider();
|
||||
const disconnectMutation = useDisconnectChannelConnection();
|
||||
const isConnected = connection?.status === "connected";
|
||||
const canConnect =
|
||||
(provider.connectable ?? (provider.enabled && provider.configured)) &&
|
||||
!isConnected;
|
||||
const isConnecting =
|
||||
connectMutation.isPending &&
|
||||
connectMutation.variables === provider.provider;
|
||||
const isDisconnecting =
|
||||
disconnectMutation.isPending &&
|
||||
disconnectMutation.variables === connection?.id;
|
||||
const connectionLabel = connection ? getConnectionLabel(connection) : null;
|
||||
const statusLabel = getStatusLabel(provider, connection, t);
|
||||
const unavailableReason = getProviderDisabledReason(provider, t);
|
||||
|
||||
return (
|
||||
<Item variant="outline" className="w-full items-start">
|
||||
<ItemMedia variant="icon" className="bg-background">
|
||||
<ChannelProviderIcon provider={provider.provider} className="size-5" />
|
||||
</ItemMedia>
|
||||
<ItemContent className="min-w-0">
|
||||
<ItemTitle className="w-full">
|
||||
<span className="truncate">{provider.display_name}</span>
|
||||
<Badge
|
||||
variant={isConnected ? "default" : "outline"}
|
||||
className={cn(!isConnected && "text-muted-foreground")}
|
||||
>
|
||||
{isConnected ? <CheckCircle2Icon /> : <AlertCircleIcon />}
|
||||
{statusLabel}
|
||||
</Badge>
|
||||
</ItemTitle>
|
||||
<ItemDescription className="line-clamp-none">
|
||||
{getProviderDescription(provider, t.channels.descriptions)}
|
||||
{connectionLabel ? ` ${t.channels.connectedAs(connectionLabel)}` : ""}
|
||||
{!isConnected && provider.unavailable_reason
|
||||
? ` ${provider.unavailable_reason}`
|
||||
: ""}
|
||||
</ItemDescription>
|
||||
</ItemContent>
|
||||
<ItemActions className="ml-auto">
|
||||
{isConnected && connection ? (
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
disabled={isDisconnecting}
|
||||
onClick={() => disconnectMutation.mutate(connection.id)}
|
||||
>
|
||||
{isDisconnecting ? (
|
||||
<LoaderCircleIcon className="animate-spin" />
|
||||
) : (
|
||||
<UnplugIcon />
|
||||
)}
|
||||
{t.channels.disconnect}
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
type="button"
|
||||
size="sm"
|
||||
disabled={!canConnect || isConnecting}
|
||||
title={unavailableReason}
|
||||
onClick={() => {
|
||||
const connectWindow =
|
||||
provider.auth_mode === "deep_link"
|
||||
? prepareConnectWindow()
|
||||
: null;
|
||||
void connectMutation
|
||||
.mutateAsync(provider.provider)
|
||||
.then((result) => {
|
||||
if (result.url) {
|
||||
openConnectUrl(result.url, connectWindow);
|
||||
return;
|
||||
}
|
||||
closeConnectWindow(connectWindow);
|
||||
toast.success(result.instruction);
|
||||
})
|
||||
.catch((error) => {
|
||||
closeConnectWindow(connectWindow);
|
||||
toast.error(
|
||||
error instanceof Error
|
||||
? error.message
|
||||
: t.channels.unavailable,
|
||||
);
|
||||
});
|
||||
}}
|
||||
>
|
||||
{isConnecting ? (
|
||||
<LoaderCircleIcon className="animate-spin" />
|
||||
) : (
|
||||
<PlugIcon />
|
||||
)}
|
||||
{connection?.status === "revoked"
|
||||
? t.channels.reconnect
|
||||
: t.channels.connect}
|
||||
</Button>
|
||||
)}
|
||||
</ItemActions>
|
||||
</Item>
|
||||
);
|
||||
}
|
||||
|
||||
export function ChannelsSettingsPage() {
|
||||
const { t } = useI18n();
|
||||
const {
|
||||
enabled,
|
||||
providers,
|
||||
isLoading: providersLoading,
|
||||
error: providersError,
|
||||
} = useChannelProviders();
|
||||
const {
|
||||
connections,
|
||||
isLoading: connectionsLoading,
|
||||
error: connectionsError,
|
||||
} = useChannelConnections();
|
||||
const isLoading = providersLoading || connectionsLoading;
|
||||
const error = providersError ?? connectionsError;
|
||||
|
||||
const connectionByProvider = new Map<string, ChannelConnection>();
|
||||
for (const connection of connections) {
|
||||
const existing = connectionByProvider.get(connection.provider);
|
||||
if (!existing || connection.status === "connected") {
|
||||
connectionByProvider.set(connection.provider, connection);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingsSection
|
||||
title={t.settings.channels.title}
|
||||
description={t.settings.channels.description}
|
||||
>
|
||||
{isLoading ? (
|
||||
<div className="text-muted-foreground text-sm">{t.common.loading}</div>
|
||||
) : error ? (
|
||||
<div className="text-destructive text-sm">{t.channels.unavailable}</div>
|
||||
) : !enabled ? (
|
||||
<div className="text-muted-foreground text-sm">
|
||||
{t.settings.channels.disabled}
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex w-full flex-col gap-4">
|
||||
{providers.map((provider) => (
|
||||
<ChannelProviderItem
|
||||
key={provider.provider}
|
||||
provider={provider}
|
||||
connection={connectionByProvider.get(provider.provider)}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</SettingsSection>
|
||||
);
|
||||
}
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import {
|
||||
BellIcon,
|
||||
CableIcon,
|
||||
InfoIcon,
|
||||
BrainIcon,
|
||||
PaletteIcon,
|
||||
@@ -22,7 +21,6 @@ import { ScrollArea } from "@/components/ui/scroll-area";
|
||||
import { AboutSettingsPage } from "@/components/workspace/settings/about-settings-page";
|
||||
import { AccountSettingsPage } from "@/components/workspace/settings/account-settings-page";
|
||||
import { AppearanceSettingsPage } from "@/components/workspace/settings/appearance-settings-page";
|
||||
import { ChannelsSettingsPage } from "@/components/workspace/settings/channels-settings-page";
|
||||
import { MemorySettingsPage } from "@/components/workspace/settings/memory-settings-page";
|
||||
import { NotificationSettingsPage } from "@/components/workspace/settings/notification-settings-page";
|
||||
import { SkillSettingsPage } from "@/components/workspace/settings/skill-settings-page";
|
||||
@@ -33,7 +31,6 @@ import { cn } from "@/lib/utils";
|
||||
type SettingsSection =
|
||||
| "account"
|
||||
| "appearance"
|
||||
| "channels"
|
||||
| "memory"
|
||||
| "tools"
|
||||
| "skills"
|
||||
@@ -75,11 +72,6 @@ export function SettingsDialog(props: SettingsDialogProps) {
|
||||
label: t.settings.sections.notification,
|
||||
icon: BellIcon,
|
||||
},
|
||||
{
|
||||
id: "channels",
|
||||
label: t.settings.sections.channels,
|
||||
icon: CableIcon,
|
||||
},
|
||||
{
|
||||
id: "memory",
|
||||
label: t.settings.sections.memory,
|
||||
@@ -92,7 +84,6 @@ export function SettingsDialog(props: SettingsDialogProps) {
|
||||
[
|
||||
t.settings.sections.account,
|
||||
t.settings.sections.appearance,
|
||||
t.settings.sections.channels,
|
||||
t.settings.sections.memory,
|
||||
t.settings.sections.tools,
|
||||
t.settings.sections.skills,
|
||||
@@ -152,7 +143,6 @@ export function SettingsDialog(props: SettingsDialogProps) {
|
||||
/>
|
||||
)}
|
||||
{activeSection === "notification" && <NotificationSettingsPage />}
|
||||
{activeSection === "channels" && <ChannelsSettingsPage />}
|
||||
{activeSection === "about" && <AboutSettingsPage />}
|
||||
</div>
|
||||
</ScrollArea>
|
||||
|
||||
@@ -9,7 +9,6 @@ import {
|
||||
useSidebar,
|
||||
} from "@/components/ui/sidebar";
|
||||
|
||||
import { WorkspaceChannelsList } from "./channels/workspace-channels-list";
|
||||
import { RecentChatList } from "./recent-chat-list";
|
||||
import { WorkspaceHeader } from "./workspace-header";
|
||||
import { WorkspaceNavChatList } from "./workspace-nav-chat-list";
|
||||
@@ -27,7 +26,6 @@ export function WorkspaceSidebar({
|
||||
</SidebarHeader>
|
||||
<SidebarContent>
|
||||
<WorkspaceNavChatList />
|
||||
<WorkspaceChannelsList />
|
||||
{isSidebarOpen && <RecentChatList />}
|
||||
</SidebarContent>
|
||||
<SidebarFooter>
|
||||
|
||||
@@ -1,78 +0,0 @@
|
||||
import { fetch } from "@/core/api/fetcher";
|
||||
import { getBackendBaseURL } from "@/core/config";
|
||||
|
||||
import type {
|
||||
ChannelConnectResponse,
|
||||
ChannelConnection,
|
||||
ChannelConnectionsResponse,
|
||||
ChannelProviderId,
|
||||
ChannelProvidersResponse,
|
||||
} from "./types";
|
||||
|
||||
function channelsUrl(path: string): string {
|
||||
return `${getBackendBaseURL()}/api/channels${path}`;
|
||||
}
|
||||
|
||||
async function throwChannelApiError(
|
||||
response: Response,
|
||||
fallback: string,
|
||||
): Promise<never> {
|
||||
const body = (await response.json().catch(() => ({}))) as {
|
||||
detail?: unknown;
|
||||
};
|
||||
throw new Error(typeof body.detail === "string" ? body.detail : fallback);
|
||||
}
|
||||
|
||||
export async function listChannelProviders(): Promise<ChannelProvidersResponse> {
|
||||
const response = await fetch(channelsUrl("/providers"));
|
||||
if (!response.ok) {
|
||||
await throwChannelApiError(
|
||||
response,
|
||||
`Failed to load channel providers: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
return response.json() as Promise<ChannelProvidersResponse>;
|
||||
}
|
||||
|
||||
export async function listChannelConnections(): Promise<ChannelConnection[]> {
|
||||
const response = await fetch(channelsUrl("/connections"));
|
||||
if (!response.ok) {
|
||||
await throwChannelApiError(
|
||||
response,
|
||||
`Failed to load channel connections: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
const data = (await response.json()) as ChannelConnectionsResponse;
|
||||
return data.connections;
|
||||
}
|
||||
|
||||
export async function connectChannelProvider(
|
||||
provider: ChannelProviderId,
|
||||
): Promise<ChannelConnectResponse> {
|
||||
const response = await fetch(
|
||||
channelsUrl(`/${encodeURIComponent(provider)}/connect`),
|
||||
{ method: "POST" },
|
||||
);
|
||||
if (!response.ok) {
|
||||
await throwChannelApiError(
|
||||
response,
|
||||
`Failed to connect ${provider}: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
return response.json() as Promise<ChannelConnectResponse>;
|
||||
}
|
||||
|
||||
export async function disconnectChannelConnection(
|
||||
connectionId: string,
|
||||
): Promise<void> {
|
||||
const response = await fetch(
|
||||
channelsUrl(`/connections/${encodeURIComponent(connectionId)}`),
|
||||
{ method: "DELETE" },
|
||||
);
|
||||
if (!response.ok) {
|
||||
await throwChannelApiError(
|
||||
response,
|
||||
`Failed to disconnect channel: ${response.statusText}`,
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
|
||||
import {
|
||||
connectChannelProvider,
|
||||
disconnectChannelConnection,
|
||||
listChannelConnections,
|
||||
listChannelProviders,
|
||||
} from "./api";
|
||||
import type { ChannelProviderId } from "./types";
|
||||
|
||||
export const channelProviderQueryKey = ["channelProviders"] as const;
|
||||
export const channelConnectionsQueryKey = ["channelConnections"] as const;
|
||||
|
||||
export function useChannelProviders() {
|
||||
const { data, isLoading, error } = useQuery({
|
||||
queryKey: channelProviderQueryKey,
|
||||
queryFn: () => listChannelProviders(),
|
||||
});
|
||||
return {
|
||||
enabled: data?.enabled ?? false,
|
||||
providers: data?.providers ?? [],
|
||||
isLoading,
|
||||
error,
|
||||
};
|
||||
}
|
||||
|
||||
export function useChannelConnections() {
|
||||
const { data, isLoading, error } = useQuery({
|
||||
queryKey: channelConnectionsQueryKey,
|
||||
queryFn: () => listChannelConnections(),
|
||||
});
|
||||
return { connections: data ?? [], isLoading, error };
|
||||
}
|
||||
|
||||
export function useConnectChannelProvider() {
|
||||
const queryClient = useQueryClient();
|
||||
return useMutation({
|
||||
mutationFn: (provider: ChannelProviderId) =>
|
||||
connectChannelProvider(provider),
|
||||
onSuccess: () => {
|
||||
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: channelConnectionsQueryKey,
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
export function useDisconnectChannelConnection() {
|
||||
const queryClient = useQueryClient();
|
||||
return useMutation({
|
||||
mutationFn: (connectionId: string) =>
|
||||
disconnectChannelConnection(connectionId),
|
||||
onSuccess: () => {
|
||||
void queryClient.invalidateQueries({ queryKey: channelProviderQueryKey });
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: channelConnectionsQueryKey,
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -1,27 +0,0 @@
|
||||
export type ChannelConnectWindow = Window | null;
|
||||
|
||||
export function prepareConnectWindow(): ChannelConnectWindow {
|
||||
const opened = window.open("about:blank", "_blank");
|
||||
if (opened) {
|
||||
opened.opener = null;
|
||||
}
|
||||
return opened;
|
||||
}
|
||||
|
||||
export function openConnectUrl(
|
||||
url: string,
|
||||
connectWindow: ChannelConnectWindow = prepareConnectWindow(),
|
||||
) {
|
||||
if (connectWindow && !connectWindow.closed) {
|
||||
connectWindow.location.replace(url);
|
||||
return;
|
||||
}
|
||||
|
||||
window.location.assign(url);
|
||||
}
|
||||
|
||||
export function closeConnectWindow(connectWindow: ChannelConnectWindow) {
|
||||
if (connectWindow && !connectWindow.closed) {
|
||||
connectWindow.close();
|
||||
}
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
export type ChannelProviderId = "telegram" | "slack" | "discord" | string;
|
||||
|
||||
export interface ChannelProvider {
|
||||
provider: ChannelProviderId;
|
||||
display_name: string;
|
||||
enabled: boolean;
|
||||
configured: boolean;
|
||||
connectable?: boolean;
|
||||
unavailable_reason?: string | null;
|
||||
auth_mode: string;
|
||||
connection_status: string;
|
||||
}
|
||||
|
||||
export interface ChannelProvidersResponse {
|
||||
enabled: boolean;
|
||||
providers: ChannelProvider[];
|
||||
}
|
||||
|
||||
export interface ChannelConnection {
|
||||
id: string;
|
||||
provider: ChannelProviderId;
|
||||
status: string;
|
||||
external_account_id?: string | null;
|
||||
external_account_name?: string | null;
|
||||
workspace_id?: string | null;
|
||||
workspace_name?: string | null;
|
||||
scopes: string[];
|
||||
metadata: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface ChannelConnectionsResponse {
|
||||
connections: ChannelConnection[];
|
||||
}
|
||||
|
||||
export interface ChannelConnectResponse {
|
||||
provider: ChannelProviderId;
|
||||
mode: string;
|
||||
url?: string | null;
|
||||
code: string;
|
||||
instruction: string;
|
||||
expires_in: number;
|
||||
}
|
||||
@@ -170,7 +170,6 @@ export const enUS: Translations = {
|
||||
sidebar: {
|
||||
newChat: "New chat",
|
||||
chats: "Chats",
|
||||
channels: "Channels",
|
||||
recentChats: "Recent chats",
|
||||
demoChats: "Demo chats",
|
||||
agents: "Agents",
|
||||
@@ -253,28 +252,9 @@ export const enUS: Translations = {
|
||||
// Chats
|
||||
chats: {
|
||||
searchChats: "Search chats",
|
||||
},
|
||||
|
||||
// Channels
|
||||
channels: {
|
||||
title: "Channels",
|
||||
connect: "Connect",
|
||||
reconnect: "Reconnect",
|
||||
disconnect: "Disconnect",
|
||||
connected: "Connected",
|
||||
notConnected: "Not connected",
|
||||
pending: "Pending",
|
||||
revoked: "Disconnected",
|
||||
disabled: "Disabled",
|
||||
unconfigured: "Not configured",
|
||||
unavailable: "Channel connections are unavailable right now.",
|
||||
unavailableShort: "Unavailable",
|
||||
descriptions: {
|
||||
telegram: "Telegram direct messages through your DeerFlow bot.",
|
||||
slack: "Slack workspace messages and mentions.",
|
||||
discord: "Discord server messages through your DeerFlow bot.",
|
||||
},
|
||||
connectedAs: (name: string) => `Connected as ${name}.`,
|
||||
loadMoreToSearch: "Load more to search older conversations",
|
||||
loadingMore: "Loading more...",
|
||||
loadOlderChats: "Load older chats",
|
||||
},
|
||||
|
||||
// Page titles (document title)
|
||||
@@ -377,7 +357,6 @@ export const enUS: Translations = {
|
||||
sections: {
|
||||
account: "Account",
|
||||
appearance: "Appearance",
|
||||
channels: "Channels",
|
||||
memory: "Memory",
|
||||
tools: "Tools",
|
||||
skills: "Skills",
|
||||
@@ -480,13 +459,6 @@ export const enUS: Translations = {
|
||||
title: "Tools",
|
||||
description: "Manage the configuration and enabled status of MCP tools.",
|
||||
},
|
||||
channels: {
|
||||
title: "Channels",
|
||||
description:
|
||||
"Connect IM accounts that can send messages to DeerFlow from outside the browser.",
|
||||
disabled:
|
||||
"Channel connections are not enabled on this server. Ask an administrator to enable channel_connections.",
|
||||
},
|
||||
skills: {
|
||||
title: "Agent Skills",
|
||||
description:
|
||||
|
||||
@@ -117,7 +117,6 @@ export interface Translations {
|
||||
chats: string;
|
||||
demoChats: string;
|
||||
agents: string;
|
||||
channels: string;
|
||||
};
|
||||
|
||||
// Agents
|
||||
@@ -184,24 +183,9 @@ export interface Translations {
|
||||
// Chats
|
||||
chats: {
|
||||
searchChats: string;
|
||||
};
|
||||
|
||||
// Channels
|
||||
channels: {
|
||||
title: string;
|
||||
connect: string;
|
||||
reconnect: string;
|
||||
disconnect: string;
|
||||
connected: string;
|
||||
notConnected: string;
|
||||
pending: string;
|
||||
revoked: string;
|
||||
disabled: string;
|
||||
unconfigured: string;
|
||||
unavailable: string;
|
||||
unavailableShort: string;
|
||||
descriptions: Record<string, string>;
|
||||
connectedAs: (name: string) => string;
|
||||
loadMoreToSearch: string;
|
||||
loadingMore: string;
|
||||
loadOlderChats: string;
|
||||
};
|
||||
|
||||
// Page titles (document title)
|
||||
@@ -300,7 +284,6 @@ export interface Translations {
|
||||
sections: {
|
||||
account: string;
|
||||
appearance: string;
|
||||
channels: string;
|
||||
memory: string;
|
||||
tools: string;
|
||||
skills: string;
|
||||
@@ -396,11 +379,6 @@ export interface Translations {
|
||||
title: string;
|
||||
description: string;
|
||||
};
|
||||
channels: {
|
||||
title: string;
|
||||
description: string;
|
||||
disabled: string;
|
||||
};
|
||||
skills: {
|
||||
title: string;
|
||||
description: string;
|
||||
|
||||
@@ -164,7 +164,6 @@ export const zhCN: Translations = {
|
||||
sidebar: {
|
||||
newChat: "新对话",
|
||||
chats: "对话",
|
||||
channels: "渠道",
|
||||
recentChats: "最近的对话",
|
||||
demoChats: "演示对话",
|
||||
agents: "智能体",
|
||||
@@ -241,28 +240,9 @@ export const zhCN: Translations = {
|
||||
// Chats
|
||||
chats: {
|
||||
searchChats: "搜索对话",
|
||||
},
|
||||
|
||||
// Channels
|
||||
channels: {
|
||||
title: "渠道",
|
||||
connect: "连接",
|
||||
reconnect: "重新连接",
|
||||
disconnect: "断开连接",
|
||||
connected: "已连接",
|
||||
notConnected: "未连接",
|
||||
pending: "待完成",
|
||||
revoked: "已断开",
|
||||
disabled: "已停用",
|
||||
unconfigured: "未配置",
|
||||
unavailable: "当前无法使用渠道连接。",
|
||||
unavailableShort: "不可用",
|
||||
descriptions: {
|
||||
telegram: "通过 DeerFlow Bot 接收 Telegram 私聊消息。",
|
||||
slack: "接收 Slack 工作区消息和提及。",
|
||||
discord: "通过 DeerFlow Bot 接收 Discord 服务器消息。",
|
||||
},
|
||||
connectedAs: (name: string) => `已连接为 ${name}。`,
|
||||
loadMoreToSearch: "加载更多以搜索更早的对话",
|
||||
loadingMore: "正在加载...",
|
||||
loadOlderChats: "加载更早的对话",
|
||||
},
|
||||
|
||||
// Page titles (document title)
|
||||
@@ -361,7 +341,6 @@ export const zhCN: Translations = {
|
||||
sections: {
|
||||
account: "账号",
|
||||
appearance: "外观",
|
||||
channels: "渠道",
|
||||
memory: "记忆",
|
||||
tools: "工具",
|
||||
skills: "技能",
|
||||
@@ -461,12 +440,6 @@ export const zhCN: Translations = {
|
||||
title: "工具",
|
||||
description: "管理 MCP 工具的配置和启用状态。",
|
||||
},
|
||||
channels: {
|
||||
title: "渠道",
|
||||
description: "连接可在浏览器外向 DeerFlow 发送消息的即时通讯账号。",
|
||||
disabled:
|
||||
"当前服务器未启用渠道连接。请联系管理员开启 channel_connections。",
|
||||
},
|
||||
skills: {
|
||||
title: "技能",
|
||||
description: "管理 Agent Skill 配置和启用状态。",
|
||||
|
||||
@@ -3,6 +3,8 @@ import type { ThreadsClient } from "@langchain/langgraph-sdk/client";
|
||||
import { useStream } from "@langchain/langgraph-sdk/react";
|
||||
import {
|
||||
type QueryClient,
|
||||
type InfiniteData,
|
||||
useInfiniteQuery,
|
||||
useMutation,
|
||||
useQuery,
|
||||
useQueryClient,
|
||||
@@ -311,6 +313,56 @@ export function upsertThreadInSearchCache(
|
||||
);
|
||||
}
|
||||
|
||||
export function upsertThreadInInfiniteCache(
|
||||
queryClient: QueryClient,
|
||||
thread: AgentThread,
|
||||
) {
|
||||
queryClient.setQueriesData(
|
||||
{
|
||||
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
exact: false,
|
||||
},
|
||||
(oldData: InfiniteData<AgentThread[]> | undefined) => {
|
||||
if (!oldData) {
|
||||
return oldData;
|
||||
}
|
||||
|
||||
const merged = oldData.pages.map((page) =>
|
||||
page.map((t) =>
|
||||
t.thread_id === thread.thread_id
|
||||
? {
|
||||
...thread,
|
||||
...t,
|
||||
metadata: {
|
||||
...(thread.metadata ?? {}),
|
||||
...(t.metadata ?? {}),
|
||||
},
|
||||
values: {
|
||||
...thread.values,
|
||||
...t.values,
|
||||
},
|
||||
}
|
||||
: t,
|
||||
),
|
||||
);
|
||||
|
||||
const exists = merged.some((page) =>
|
||||
page.some((t) => t.thread_id === thread.thread_id),
|
||||
);
|
||||
if (exists) {
|
||||
return { ...oldData, pages: merged };
|
||||
}
|
||||
|
||||
const firstPage = merged[0] ?? [];
|
||||
const restPages = merged.slice(1);
|
||||
return {
|
||||
...oldData,
|
||||
pages: [[thread, ...firstPage], ...restPages],
|
||||
};
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
function getStreamErrorMessage(error: unknown): string {
|
||||
if (typeof error === "string" && error.trim()) {
|
||||
return error;
|
||||
@@ -417,6 +469,19 @@ export function useThreadStream({
|
||||
},
|
||||
interrupts: {},
|
||||
});
|
||||
upsertThreadInInfiniteCache(queryClient, {
|
||||
thread_id: meta.thread_id,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
metadata: context.agent_name ? { agent_name: context.agent_name } : {},
|
||||
status: "busy",
|
||||
values: {
|
||||
title: t.pages.newChat,
|
||||
messages: [],
|
||||
artifacts: [],
|
||||
},
|
||||
interrupts: {},
|
||||
});
|
||||
if (context.agent_name && !isMock) {
|
||||
void getAPIClient()
|
||||
.threads.update(meta.thread_id, {
|
||||
@@ -488,6 +553,27 @@ export function useThreadStream({
|
||||
});
|
||||
},
|
||||
);
|
||||
const nextTitle: string = update.title;
|
||||
void queryClient.setQueriesData(
|
||||
{
|
||||
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
exact: false,
|
||||
},
|
||||
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||
mapInfiniteThreadsCache(
|
||||
oldData,
|
||||
(t): AgentThread =>
|
||||
t.thread_id === threadIdRef.current
|
||||
? {
|
||||
...t,
|
||||
values: {
|
||||
...t.values,
|
||||
title: nextTitle,
|
||||
},
|
||||
}
|
||||
: t,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -542,6 +628,9 @@ export function useThreadStream({
|
||||
.filter((id): id is string => Boolean(id)),
|
||||
);
|
||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
});
|
||||
if (threadIdRef.current && !isMock) {
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
|
||||
@@ -801,6 +890,9 @@ export function useThreadStream({
|
||||
},
|
||||
);
|
||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
});
|
||||
} catch (error) {
|
||||
setOptimisticMessages([]);
|
||||
setIsUploading(false);
|
||||
@@ -1100,6 +1192,86 @@ export function useThreads(
|
||||
});
|
||||
}
|
||||
|
||||
export const INFINITE_THREADS_PAGE_SIZE = 50;
|
||||
|
||||
export const INFINITE_THREADS_QUERY_KEY_PREFIX = [
|
||||
"threads",
|
||||
"searchInfinite",
|
||||
] as const;
|
||||
|
||||
type InfiniteThreadsParams = Omit<
|
||||
Parameters<ThreadsClient["search"]>[0],
|
||||
"limit" | "offset"
|
||||
>;
|
||||
|
||||
export function getInfiniteThreadsNextPageParam(
|
||||
lastPage: AgentThread[],
|
||||
allPages: AgentThread[][],
|
||||
pageSize: number = INFINITE_THREADS_PAGE_SIZE,
|
||||
): number | undefined {
|
||||
if (lastPage.length < pageSize) {
|
||||
return undefined;
|
||||
}
|
||||
return allPages.reduce((sum, page) => sum + page.length, 0);
|
||||
}
|
||||
|
||||
export function mapInfiniteThreadsCache(
|
||||
oldData: InfiniteData<AgentThread[]> | undefined,
|
||||
mapper: (thread: AgentThread) => AgentThread,
|
||||
): InfiniteData<AgentThread[]> | undefined {
|
||||
if (!oldData) {
|
||||
return oldData;
|
||||
}
|
||||
return {
|
||||
...oldData,
|
||||
pages: oldData.pages.map((page) => page.map(mapper)),
|
||||
};
|
||||
}
|
||||
|
||||
export function filterInfiniteThreadsCache(
|
||||
oldData: InfiniteData<AgentThread[]> | undefined,
|
||||
predicate: (thread: AgentThread) => boolean,
|
||||
): InfiniteData<AgentThread[]> | undefined {
|
||||
if (!oldData) {
|
||||
return oldData;
|
||||
}
|
||||
return {
|
||||
...oldData,
|
||||
pages: oldData.pages.map((page) => page.filter(predicate)),
|
||||
};
|
||||
}
|
||||
|
||||
export function useInfiniteThreads(
|
||||
params: InfiniteThreadsParams = {
|
||||
sortBy: "updated_at",
|
||||
sortOrder: "desc",
|
||||
select: ["thread_id", "updated_at", "values", "metadata"],
|
||||
},
|
||||
) {
|
||||
const apiClient = getAPIClient();
|
||||
return useInfiniteQuery<
|
||||
AgentThread[],
|
||||
Error,
|
||||
InfiniteData<AgentThread[]>,
|
||||
readonly unknown[],
|
||||
number
|
||||
>({
|
||||
queryKey: [...INFINITE_THREADS_QUERY_KEY_PREFIX, params],
|
||||
initialPageParam: 0,
|
||||
queryFn: async ({ pageParam }) => {
|
||||
const response = (await apiClient.threads.search<AgentThreadState>({
|
||||
...params,
|
||||
limit: INFINITE_THREADS_PAGE_SIZE,
|
||||
offset: pageParam,
|
||||
})) as AgentThread[];
|
||||
return response;
|
||||
},
|
||||
getNextPageParam: (lastPage, allPages) =>
|
||||
getInfiniteThreadsNextPageParam(lastPage, allPages),
|
||||
refetchOnWindowFocus: false,
|
||||
});
|
||||
}
|
||||
|
||||
export function useThreadRuns(
|
||||
threadId?: string,
|
||||
{ enabled = true }: { enabled?: boolean } = {},
|
||||
@@ -1183,9 +1355,21 @@ export function useDeleteThread() {
|
||||
return oldData.filter((t) => t.thread_id !== threadId);
|
||||
},
|
||||
);
|
||||
queryClient.setQueriesData(
|
||||
{
|
||||
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
exact: false,
|
||||
},
|
||||
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||
filterInfiniteThreadsCache(oldData, (t) => t.thread_id !== threadId),
|
||||
);
|
||||
},
|
||||
|
||||
onSettled() {
|
||||
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
|
||||
void queryClient.invalidateQueries({
|
||||
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
});
|
||||
},
|
||||
});
|
||||
}
|
||||
@@ -1226,6 +1410,24 @@ export function useRenameThread() {
|
||||
});
|
||||
},
|
||||
);
|
||||
queryClient.setQueriesData(
|
||||
{
|
||||
queryKey: INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
exact: false,
|
||||
},
|
||||
(oldData: InfiniteData<AgentThread[]> | undefined) =>
|
||||
mapInfiniteThreadsCache(oldData, (t) =>
|
||||
t.thread_id === threadId
|
||||
? {
|
||||
...t,
|
||||
values: {
|
||||
...t.values,
|
||||
title,
|
||||
},
|
||||
}
|
||||
: t,
|
||||
),
|
||||
);
|
||||
},
|
||||
});
|
||||
}
|
||||
|
||||
@@ -1,102 +0,0 @@
|
||||
import { expect, test, type Page } from "@playwright/test";
|
||||
|
||||
import { mockLangGraphAPI } from "./utils/mock-api";
|
||||
|
||||
function mockChannelsAPI(page: Page) {
|
||||
void page.route("**/api/channels/providers", (route) => {
|
||||
return route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
enabled: true,
|
||||
providers: [
|
||||
{
|
||||
provider: "telegram",
|
||||
display_name: "Telegram",
|
||||
enabled: true,
|
||||
configured: true,
|
||||
auth_mode: "deep_link",
|
||||
connection_status: "not_connected",
|
||||
},
|
||||
{
|
||||
provider: "slack",
|
||||
display_name: "Slack",
|
||||
enabled: true,
|
||||
configured: true,
|
||||
auth_mode: "binding_code",
|
||||
connection_status: "not_connected",
|
||||
},
|
||||
{
|
||||
provider: "discord",
|
||||
display_name: "Discord",
|
||||
enabled: true,
|
||||
configured: true,
|
||||
auth_mode: "binding_code",
|
||||
connection_status: "not_connected",
|
||||
},
|
||||
],
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
void page.route("**/api/channels/connections", (route) => {
|
||||
return route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ connections: [] }),
|
||||
});
|
||||
});
|
||||
|
||||
void page.route("**/api/channels/slack/connect", (route) => {
|
||||
return route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
provider: "slack",
|
||||
mode: "binding_code",
|
||||
url: null,
|
||||
code: "abc123",
|
||||
instruction: "Send /connect abc123 to the DeerFlow Slack bot.",
|
||||
expires_in: 600,
|
||||
}),
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
test.describe("IM channels", () => {
|
||||
test("sidebar and settings expose channel connections", async ({ page }) => {
|
||||
mockLangGraphAPI(page);
|
||||
mockChannelsAPI(page);
|
||||
|
||||
await page.goto("/workspace/chats/new");
|
||||
|
||||
const sidebar = page.locator("[data-sidebar='sidebar']");
|
||||
await expect(sidebar.getByText("Channels")).toBeVisible({
|
||||
timeout: 15_000,
|
||||
});
|
||||
await expect(sidebar.getByText("Telegram")).toBeVisible();
|
||||
await expect(sidebar.getByText("Slack")).toBeVisible();
|
||||
await expect(sidebar.getByText("Discord")).toBeVisible();
|
||||
await expect(sidebar.getByRole("button", { name: "Connect" })).toHaveCount(
|
||||
3,
|
||||
);
|
||||
|
||||
await sidebar.getByRole("button", { name: /Settings and more/ }).click();
|
||||
await page.getByRole("menuitem", { name: "Settings" }).click();
|
||||
await page.getByRole("button", { name: "Channels" }).click();
|
||||
|
||||
await expect(page.getByText("Telegram direct messages")).toBeVisible();
|
||||
await expect(page.getByText("Slack workspace messages")).toBeVisible();
|
||||
await expect(page.getByText("Discord server messages")).toBeVisible();
|
||||
|
||||
const dialog = page.getByRole("dialog", { name: "Settings" });
|
||||
const connectButtons = dialog.getByRole("button", { name: "Connect" });
|
||||
await expect(connectButtons).toHaveCount(3);
|
||||
|
||||
await connectButtons.nth(1).click();
|
||||
await expect(page).toHaveURL(/\/workspace\/chats\/new/);
|
||||
await expect(
|
||||
page.getByText("Send /connect abc123 to the DeerFlow Slack bot."),
|
||||
).toBeVisible();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,123 @@
|
||||
import { expect, test } from "@playwright/test";
|
||||
|
||||
import { mockLangGraphAPI } from "./utils/mock-api";
|
||||
|
||||
// Issue #3482: the sidebar's "Recent chats" and the /workspace/chats list
|
||||
// page used to stop at the first 50 threads with no way to load more.
|
||||
// `useInfiniteThreads()` + a sentinel near the bottom of each list now
|
||||
// pages through the backend.
|
||||
|
||||
const TOTAL_THREADS = 120;
|
||||
const PAGE_SIZE = 50;
|
||||
|
||||
const THREADS = Array.from({ length: TOTAL_THREADS }, (_, i) => {
|
||||
// Pad index so titles sort deterministically as strings. The thread-search
|
||||
// mock returns threads in the order provided, so paging boundaries are
|
||||
// stable across runs.
|
||||
const index = String(i + 1).padStart(3, "0");
|
||||
return {
|
||||
thread_id: `00000000-0000-0000-0000-0000000${index.padStart(5, "0")}`,
|
||||
title: `Conversation ${index}`,
|
||||
updated_at: `2025-06-${String((i % 28) + 1).padStart(2, "0")}T12:00:00Z`,
|
||||
};
|
||||
});
|
||||
|
||||
const FIRST_PAGE_LAST = `Conversation ${String(PAGE_SIZE).padStart(3, "0")}`;
|
||||
const SECOND_PAGE_FIRST = `Conversation ${String(PAGE_SIZE + 1).padStart(3, "0")}`;
|
||||
|
||||
test.describe("Thread list infinite scroll (issue #3482)", () => {
|
||||
test("chats list page loads more threads when scrolling to the bottom", async ({
|
||||
page,
|
||||
}) => {
|
||||
mockLangGraphAPI(page, { threads: THREADS });
|
||||
|
||||
await page.goto("/workspace/chats");
|
||||
|
||||
const main = page.locator("main");
|
||||
|
||||
// First page renders.
|
||||
await expect(main.getByText(FIRST_PAGE_LAST)).toBeVisible({
|
||||
timeout: 15_000,
|
||||
});
|
||||
// Items past the first page have not been fetched yet.
|
||||
await expect(main.getByText(SECOND_PAGE_FIRST)).toHaveCount(0);
|
||||
|
||||
// Scrolling the sentinel into view triggers the next page.
|
||||
const sentinel = page.getByTestId("chats-page-sentinel");
|
||||
await sentinel.scrollIntoViewIfNeeded();
|
||||
|
||||
await expect(main.getByText(SECOND_PAGE_FIRST)).toBeVisible({
|
||||
timeout: 15_000,
|
||||
});
|
||||
});
|
||||
|
||||
test("sidebar recent chats loads more threads when scrolling to the bottom", async ({
|
||||
page,
|
||||
}) => {
|
||||
mockLangGraphAPI(page, { threads: THREADS });
|
||||
|
||||
await page.goto("/workspace/chats/new");
|
||||
|
||||
// The 50th thread (end of first page) appears in the sidebar.
|
||||
await expect(page.getByText(FIRST_PAGE_LAST).first()).toBeVisible({
|
||||
timeout: 15_000,
|
||||
});
|
||||
// The 51st has not been fetched yet.
|
||||
await expect(page.getByText(SECOND_PAGE_FIRST)).toHaveCount(0);
|
||||
|
||||
// Scroll the sidebar sentinel into view to trigger the next page.
|
||||
const sentinel = page.getByTestId("recent-chat-list-sentinel");
|
||||
await sentinel.scrollIntoViewIfNeeded();
|
||||
|
||||
await expect(page.getByText(SECOND_PAGE_FIRST).first()).toBeVisible({
|
||||
timeout: 15_000,
|
||||
});
|
||||
});
|
||||
|
||||
test("chats list page does NOT auto-paginate while a search filter is active", async ({
|
||||
page,
|
||||
}) => {
|
||||
// Count search requests via a passive request observer. Using
|
||||
// page.route() here would race with mockLangGraphAPI's fulfill route
|
||||
// (Playwright matches routes in reverse registration order), so the
|
||||
// counter could miss real requests. page.on('request') is a pure
|
||||
// observer and never interferes with routing.
|
||||
let searchRequestCount = 0;
|
||||
page.on("request", (request) => {
|
||||
if (request.url().includes("/api/langgraph/threads/search")) {
|
||||
searchRequestCount += 1;
|
||||
}
|
||||
});
|
||||
|
||||
mockLangGraphAPI(page, { threads: THREADS });
|
||||
|
||||
await page.goto("/workspace/chats");
|
||||
|
||||
// Wait for the first page to render so we have a baseline count.
|
||||
await expect(page.locator("main").getByText(FIRST_PAGE_LAST)).toBeVisible({
|
||||
timeout: 15_000,
|
||||
});
|
||||
const baselineRequests = searchRequestCount;
|
||||
|
||||
// Type a query that matches nothing in the first page (and nothing at
|
||||
// all, since titles are deterministic).
|
||||
await page
|
||||
.getByPlaceholder("Search chats")
|
||||
.fill("zzz-no-such-conversation");
|
||||
|
||||
// The auto-sentinel must be gone; an explicit button takes its place.
|
||||
await expect(page.getByTestId("chats-page-sentinel")).toHaveCount(0);
|
||||
await expect(page.getByTestId("chats-page-load-more")).toBeVisible();
|
||||
|
||||
// Give the IntersectionObserver a couple of frames to misbehave if the
|
||||
// guard regresses. No additional /threads/search calls should fire.
|
||||
await page.waitForTimeout(500);
|
||||
expect(searchRequestCount).toBe(baselineRequests);
|
||||
|
||||
// The explicit button still works as an escape hatch.
|
||||
await page.getByTestId("chats-page-load-more").click();
|
||||
await expect
|
||||
.poll(() => searchRequestCount, { timeout: 10_000 })
|
||||
.toBeGreaterThan(baselineRequests);
|
||||
});
|
||||
});
|
||||
@@ -85,7 +85,7 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) {
|
||||
const skills = options?.skills ?? DEFAULT_SKILLS;
|
||||
|
||||
// Thread search — sidebar thread list & chats list page
|
||||
void page.route("**/api/langgraph/threads/search", (route) => {
|
||||
void page.route("**/api/langgraph/threads/search", async (route) => {
|
||||
const body = threads.map((t) => ({
|
||||
thread_id: t.thread_id,
|
||||
created_at: "2025-01-01T00:00:00Z",
|
||||
@@ -94,10 +94,33 @@ export function mockLangGraphAPI(page: Page, options?: MockAPIOptions) {
|
||||
status: "idle",
|
||||
values: { title: t.title ?? "Untitled" },
|
||||
}));
|
||||
|
||||
let limit: number | undefined;
|
||||
let offset = 0;
|
||||
try {
|
||||
const postData = route.request().postDataJSON() as {
|
||||
limit?: number;
|
||||
offset?: number;
|
||||
} | null;
|
||||
if (postData) {
|
||||
if (typeof postData.limit === "number") {
|
||||
limit = postData.limit;
|
||||
}
|
||||
if (typeof postData.offset === "number") {
|
||||
offset = postData.offset;
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// No / invalid JSON body — fall back to returning the full list.
|
||||
}
|
||||
|
||||
const sliced =
|
||||
typeof limit === "number" ? body.slice(offset, offset + limit) : body;
|
||||
|
||||
return route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify(body),
|
||||
body: JSON.stringify(sliced),
|
||||
});
|
||||
});
|
||||
|
||||
|
||||
@@ -1,146 +0,0 @@
|
||||
import { beforeEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
vi.mock("@/core/api/fetcher", () => ({
|
||||
fetch: vi.fn(),
|
||||
}));
|
||||
|
||||
vi.mock("@/core/config", () => ({
|
||||
getBackendBaseURL: () => "/backend",
|
||||
}));
|
||||
|
||||
import { fetch as fetcher } from "@/core/api/fetcher";
|
||||
import {
|
||||
connectChannelProvider,
|
||||
disconnectChannelConnection,
|
||||
listChannelConnections,
|
||||
listChannelProviders,
|
||||
} from "@/core/channels/api";
|
||||
|
||||
const mockedFetch = vi.mocked(fetcher);
|
||||
|
||||
function jsonResponse(status: number, body: unknown): Response {
|
||||
return new Response(JSON.stringify(body), {
|
||||
status,
|
||||
statusText: status >= 400 ? "Bad Request" : "OK",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
});
|
||||
}
|
||||
|
||||
beforeEach(() => {
|
||||
mockedFetch.mockReset();
|
||||
});
|
||||
|
||||
describe("channels api", () => {
|
||||
test("loads provider catalog", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
jsonResponse(200, {
|
||||
enabled: true,
|
||||
providers: [
|
||||
{
|
||||
provider: "telegram",
|
||||
display_name: "Telegram",
|
||||
enabled: true,
|
||||
configured: true,
|
||||
auth_mode: "deep_link",
|
||||
connection_status: "not_connected",
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(listChannelProviders()).resolves.toMatchObject({
|
||||
enabled: true,
|
||||
providers: [{ provider: "telegram", display_name: "Telegram" }],
|
||||
});
|
||||
expect(mockedFetch).toHaveBeenCalledWith("/backend/api/channels/providers");
|
||||
});
|
||||
|
||||
test("loads current user's connections", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
jsonResponse(200, {
|
||||
connections: [
|
||||
{
|
||||
id: "connection-1",
|
||||
provider: "telegram",
|
||||
status: "connected",
|
||||
external_account_name: "Alice",
|
||||
scopes: [],
|
||||
metadata: {},
|
||||
},
|
||||
],
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(listChannelConnections()).resolves.toMatchObject([
|
||||
{ id: "connection-1", provider: "telegram", status: "connected" },
|
||||
]);
|
||||
expect(mockedFetch).toHaveBeenCalledWith(
|
||||
"/backend/api/channels/connections",
|
||||
);
|
||||
});
|
||||
|
||||
test("starts a provider connection flow", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
jsonResponse(200, {
|
||||
provider: "telegram",
|
||||
mode: "deep_link",
|
||||
url: "https://t.me/deerflow_bot?start=state",
|
||||
code: "state",
|
||||
instruction: "Send /start state to the DeerFlow Telegram bot.",
|
||||
expires_in: 600,
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(connectChannelProvider("telegram")).resolves.toMatchObject({
|
||||
provider: "telegram",
|
||||
url: "https://t.me/deerflow_bot?start=state",
|
||||
instruction: "Send /start state to the DeerFlow Telegram bot.",
|
||||
});
|
||||
expect(mockedFetch).toHaveBeenCalledWith(
|
||||
"/backend/api/channels/telegram/connect",
|
||||
{ method: "POST" },
|
||||
);
|
||||
});
|
||||
|
||||
test("starts a binding-code connection flow", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
jsonResponse(200, {
|
||||
provider: "slack",
|
||||
mode: "binding_code",
|
||||
url: null,
|
||||
code: "abc123",
|
||||
instruction: "Send /connect abc123 to the DeerFlow Slack bot.",
|
||||
expires_in: 600,
|
||||
}),
|
||||
);
|
||||
|
||||
await expect(connectChannelProvider("slack")).resolves.toMatchObject({
|
||||
provider: "slack",
|
||||
url: null,
|
||||
code: "abc123",
|
||||
instruction: "Send /connect abc123 to the DeerFlow Slack bot.",
|
||||
});
|
||||
});
|
||||
|
||||
test("disconnects a channel connection", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(new Response(null, { status: 204 }));
|
||||
|
||||
await expect(
|
||||
disconnectChannelConnection("connection-1"),
|
||||
).resolves.toBeUndefined();
|
||||
expect(mockedFetch).toHaveBeenCalledWith(
|
||||
"/backend/api/channels/connections/connection-1",
|
||||
{ method: "DELETE" },
|
||||
);
|
||||
});
|
||||
|
||||
test("uses backend detail for failed requests", async () => {
|
||||
mockedFetch.mockResolvedValueOnce(
|
||||
jsonResponse(400, { detail: "Channel provider is not configured" }),
|
||||
);
|
||||
|
||||
await expect(connectChannelProvider("slack")).rejects.toThrow(
|
||||
"Channel provider is not configured",
|
||||
);
|
||||
});
|
||||
});
|
||||
@@ -1,86 +0,0 @@
|
||||
import { afterEach, describe, expect, test, vi } from "vitest";
|
||||
|
||||
import {
|
||||
closeConnectWindow,
|
||||
openConnectUrl,
|
||||
prepareConnectWindow,
|
||||
} from "@/core/channels/open-connect-url";
|
||||
|
||||
type PopupStub = {
|
||||
closed: boolean;
|
||||
close: ReturnType<typeof vi.fn>;
|
||||
location: {
|
||||
replace: ReturnType<typeof vi.fn>;
|
||||
};
|
||||
opener: unknown;
|
||||
};
|
||||
|
||||
function stubWindow(openResult: PopupStub | null) {
|
||||
const assign = vi.fn();
|
||||
const open = vi.fn(() => openResult);
|
||||
vi.stubGlobal("window", {
|
||||
open,
|
||||
location: { assign },
|
||||
});
|
||||
return { assign, open };
|
||||
}
|
||||
|
||||
function makePopup(): PopupStub {
|
||||
return {
|
||||
closed: false,
|
||||
close: vi.fn(),
|
||||
location: { replace: vi.fn() },
|
||||
opener: {},
|
||||
};
|
||||
}
|
||||
|
||||
afterEach(() => {
|
||||
vi.unstubAllGlobals();
|
||||
});
|
||||
|
||||
describe("channel connect window helpers", () => {
|
||||
test("opens a blank tab synchronously and detaches opener", () => {
|
||||
const popup = makePopup();
|
||||
const { open } = stubWindow(popup);
|
||||
|
||||
const prepared = prepareConnectWindow();
|
||||
|
||||
expect(open).toHaveBeenCalledWith("about:blank", "_blank");
|
||||
expect(prepared).toBe(popup);
|
||||
expect(popup.opener).toBeNull();
|
||||
});
|
||||
|
||||
test("navigates a prepared popup without opening another window", () => {
|
||||
const popup = makePopup();
|
||||
const { assign, open } = stubWindow(null);
|
||||
|
||||
openConnectUrl(
|
||||
"https://t.me/deerflow_bot?start=state",
|
||||
popup as unknown as Window,
|
||||
);
|
||||
|
||||
expect(open).not.toHaveBeenCalled();
|
||||
expect(assign).not.toHaveBeenCalled();
|
||||
expect(popup.location.replace).toHaveBeenCalledWith(
|
||||
"https://t.me/deerflow_bot?start=state",
|
||||
);
|
||||
});
|
||||
|
||||
test("falls back to current-window navigation when no popup is available", () => {
|
||||
const { assign } = stubWindow(null);
|
||||
|
||||
openConnectUrl("https://t.me/deerflow_bot?start=state");
|
||||
|
||||
expect(assign).toHaveBeenCalledWith(
|
||||
"https://t.me/deerflow_bot?start=state",
|
||||
);
|
||||
});
|
||||
|
||||
test("closes a prepared popup on connect failure", () => {
|
||||
const popup = makePopup();
|
||||
|
||||
closeConnectWindow(popup as unknown as Window);
|
||||
|
||||
expect(popup.close).toHaveBeenCalled();
|
||||
});
|
||||
});
|
||||
@@ -0,0 +1,228 @@
|
||||
import { QueryClient, type InfiniteData } from "@tanstack/react-query";
|
||||
import { describe, expect, test } from "vitest";
|
||||
|
||||
import {
|
||||
filterInfiniteThreadsCache,
|
||||
getInfiniteThreadsNextPageParam,
|
||||
INFINITE_THREADS_PAGE_SIZE,
|
||||
INFINITE_THREADS_QUERY_KEY_PREFIX,
|
||||
mapInfiniteThreadsCache,
|
||||
upsertThreadInInfiniteCache,
|
||||
} from "@/core/threads/hooks";
|
||||
import type { AgentThread } from "@/core/threads/types";
|
||||
|
||||
// Issue #3482: the sidebar and /workspace/chats list used to be capped at
|
||||
// 50 threads because `useThreads()` exits as soon as `threads.length >=
|
||||
// params.limit`. These pure helpers back the `useInfiniteThreads()`
|
||||
// pagination logic and the mirrored cache writes that keep rename / delete
|
||||
// / stream-finish in sync with both the legacy array cache and the new
|
||||
// infinite cache.
|
||||
|
||||
function makeThread(id: string, title = `Title ${id}`): AgentThread {
|
||||
return {
|
||||
thread_id: id,
|
||||
created_at: "2025-01-01T00:00:00Z",
|
||||
updated_at: "2025-01-01T00:00:00Z",
|
||||
metadata: {},
|
||||
status: "idle",
|
||||
values: { title },
|
||||
} as unknown as AgentThread;
|
||||
}
|
||||
|
||||
function makePage(start: number, size: number): AgentThread[] {
|
||||
return Array.from({ length: size }, (_, i) => makeThread(`t-${start + i}`));
|
||||
}
|
||||
|
||||
function makeInfiniteData(pages: AgentThread[][]): InfiniteData<AgentThread[]> {
|
||||
return {
|
||||
pages,
|
||||
pageParams: pages.map((_, i) => i * INFINITE_THREADS_PAGE_SIZE),
|
||||
};
|
||||
}
|
||||
|
||||
describe("getInfiniteThreadsNextPageParam", () => {
|
||||
test("returns next offset when the last page is full", () => {
|
||||
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||
expect(getInfiniteThreadsNextPageParam(page1, [page1])).toBe(
|
||||
INFINITE_THREADS_PAGE_SIZE,
|
||||
);
|
||||
});
|
||||
|
||||
test("returns next offset across multiple full pages", () => {
|
||||
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||
const page2 = makePage(
|
||||
INFINITE_THREADS_PAGE_SIZE,
|
||||
INFINITE_THREADS_PAGE_SIZE,
|
||||
);
|
||||
expect(getInfiniteThreadsNextPageParam(page2, [page1, page2])).toBe(
|
||||
INFINITE_THREADS_PAGE_SIZE * 2,
|
||||
);
|
||||
});
|
||||
|
||||
test("returns undefined when the last page is short (end of list)", () => {
|
||||
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||
const page2 = makePage(INFINITE_THREADS_PAGE_SIZE, 10);
|
||||
expect(
|
||||
getInfiniteThreadsNextPageParam(page2, [page1, page2]),
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
test("returns undefined when the last page is empty", () => {
|
||||
const page1 = makePage(0, INFINITE_THREADS_PAGE_SIZE);
|
||||
expect(getInfiniteThreadsNextPageParam([], [page1, []])).toBeUndefined();
|
||||
});
|
||||
|
||||
test("respects a custom page size", () => {
|
||||
const page1 = makePage(0, 5);
|
||||
expect(getInfiniteThreadsNextPageParam(page1, [page1], 5)).toBe(5);
|
||||
expect(getInfiniteThreadsNextPageParam(page1, [page1], 10)).toBeUndefined();
|
||||
});
|
||||
});
|
||||
|
||||
describe("mapInfiniteThreadsCache", () => {
|
||||
test("returns undefined when oldData is undefined", () => {
|
||||
expect(mapInfiniteThreadsCache(undefined, (t) => t)).toBeUndefined();
|
||||
});
|
||||
|
||||
test("updates the matching thread across multiple pages", () => {
|
||||
const page1 = [makeThread("a"), makeThread("b")];
|
||||
const page2 = [makeThread("c"), makeThread("d")];
|
||||
const data = makeInfiniteData([page1, page2]);
|
||||
|
||||
const updated = mapInfiniteThreadsCache(data, (t) =>
|
||||
t.thread_id === "c"
|
||||
? { ...t, values: { ...t.values, title: "renamed" } }
|
||||
: t,
|
||||
);
|
||||
|
||||
expect(updated?.pages[0]?.[0]?.values?.title).toBe("Title a");
|
||||
expect(updated?.pages[1]?.[0]?.thread_id).toBe("c");
|
||||
expect(updated?.pages[1]?.[0]?.values?.title).toBe("renamed");
|
||||
expect(updated?.pages[1]?.[1]?.values?.title).toBe("Title d");
|
||||
});
|
||||
|
||||
test("preserves pageParams", () => {
|
||||
const data = makeInfiniteData([[makeThread("a")]]);
|
||||
const updated = mapInfiniteThreadsCache(data, (t) => t);
|
||||
expect(updated?.pageParams).toEqual(data.pageParams);
|
||||
});
|
||||
});
|
||||
|
||||
describe("filterInfiniteThreadsCache", () => {
|
||||
test("returns undefined when oldData is undefined", () => {
|
||||
expect(filterInfiniteThreadsCache(undefined, () => true)).toBeUndefined();
|
||||
});
|
||||
|
||||
test("removes matching threads across all pages", () => {
|
||||
const page1 = [makeThread("a"), makeThread("b")];
|
||||
const page2 = [makeThread("b"), makeThread("c")];
|
||||
const data = makeInfiniteData([page1, page2]);
|
||||
|
||||
const filtered = filterInfiniteThreadsCache(
|
||||
data,
|
||||
(t) => t.thread_id !== "b",
|
||||
);
|
||||
|
||||
expect(filtered?.pages[0]?.map((t) => t.thread_id)).toEqual(["a"]);
|
||||
expect(filtered?.pages[1]?.map((t) => t.thread_id)).toEqual(["c"]);
|
||||
});
|
||||
|
||||
test("keeps an emptied page as an empty array (does not drop the page)", () => {
|
||||
const page1 = [makeThread("a")];
|
||||
const page2 = [makeThread("b")];
|
||||
const data = makeInfiniteData([page1, page2]);
|
||||
|
||||
const filtered = filterInfiniteThreadsCache(
|
||||
data,
|
||||
(t) => t.thread_id !== "a",
|
||||
);
|
||||
|
||||
expect(filtered?.pages).toHaveLength(2);
|
||||
expect(filtered?.pages[0]).toEqual([]);
|
||||
expect(filtered?.pages[1]?.[0]?.thread_id).toBe("b");
|
||||
});
|
||||
|
||||
test("does not regress next offset when an earlier page has been shrunk by a delete", () => {
|
||||
// Simulate two full pages already loaded.
|
||||
const page1 = Array.from({ length: 50 }, (_, i) => ({
|
||||
thread_id: `a${i}`,
|
||||
}));
|
||||
const page2 = Array.from({ length: 50 }, (_, i) => ({
|
||||
thread_id: `b${i}`,
|
||||
}));
|
||||
|
||||
// Offset right after fetching page 2 (this is the value TanStack Query
|
||||
// freezes into pageParams).
|
||||
const offsetAfterPage2 = getInfiniteThreadsNextPageParam(
|
||||
page2 as unknown as AgentThread[],
|
||||
[page1, page2] as unknown as AgentThread[][],
|
||||
);
|
||||
expect(offsetAfterPage2).toBe(100);
|
||||
|
||||
// Now a delete mutation runs filterInfiniteThreadsCache and shrinks
|
||||
// page 1 from 50 to 49 entries. TanStack does NOT re-invoke
|
||||
// getNextPageParam on cache mutations; the previously-computed offset
|
||||
// (100) remains the param for the next fetchNextPage() call, so the
|
||||
// helper is consistent with how the library uses its return value.
|
||||
const shrunkPage1 = page1.slice(0, 49);
|
||||
const recomputed = getInfiniteThreadsNextPageParam(
|
||||
page2 as unknown as AgentThread[],
|
||||
[shrunkPage1, page2] as unknown as AgentThread[][],
|
||||
);
|
||||
// We document the recomputed value for completeness, but in practice
|
||||
// useDeleteThread invalidates the query in onSettled, so pages are
|
||||
// refetched from offset 0 rather than relying on this number.
|
||||
expect(recomputed).toBe(99);
|
||||
});
|
||||
});
|
||||
|
||||
describe("upsertThreadInInfiniteCache", () => {
|
||||
function seedClient(initial?: InfiniteData<AgentThread[]>): QueryClient {
|
||||
const client = new QueryClient();
|
||||
if (initial) {
|
||||
client.setQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}], initial);
|
||||
}
|
||||
return client;
|
||||
}
|
||||
|
||||
function readCache(
|
||||
client: QueryClient,
|
||||
): InfiniteData<AgentThread[]> | undefined {
|
||||
return client.getQueryData([...INFINITE_THREADS_QUERY_KEY_PREFIX, {}]);
|
||||
}
|
||||
|
||||
test("no-op when the infinite cache has not been initialised yet", () => {
|
||||
const client = seedClient();
|
||||
upsertThreadInInfiniteCache(client, makeThread("new"));
|
||||
expect(readCache(client)).toBeUndefined();
|
||||
});
|
||||
|
||||
test("prepends a brand-new thread to the first page", () => {
|
||||
const client = seedClient({
|
||||
pages: [[makeThread("a"), makeThread("b")]],
|
||||
pageParams: [0],
|
||||
});
|
||||
upsertThreadInInfiniteCache(client, makeThread("new"));
|
||||
const cache = readCache(client);
|
||||
expect(cache?.pages[0]?.map((t) => t.thread_id)).toEqual(["new", "a", "b"]);
|
||||
});
|
||||
|
||||
test("merges into the existing entry instead of duplicating it", () => {
|
||||
const existing = makeThread("a", "Old title");
|
||||
const client = seedClient({
|
||||
pages: [[existing, makeThread("b")]],
|
||||
pageParams: [0],
|
||||
});
|
||||
// Simulate an onCreated upsert that races with a thread already in cache:
|
||||
// the cache copy should win for title/metadata (it represents later state),
|
||||
// but no duplicate row should appear.
|
||||
upsertThreadInInfiniteCache(client, {
|
||||
...makeThread("a", "New title"),
|
||||
status: "busy",
|
||||
});
|
||||
const cache = readCache(client);
|
||||
const ids = cache?.pages[0]?.map((t) => t.thread_id);
|
||||
expect(ids).toEqual(["a", "b"]);
|
||||
expect(cache?.pages[0]?.[0]?.values.title).toBe("Old title");
|
||||
});
|
||||
});
|
||||
+13
-95
@@ -29,6 +29,14 @@ set -e
|
||||
REPO_ROOT="$(builtin cd "$(dirname "${BASH_SOURCE[0]}")/.." >/dev/null 2>&1 && pwd -P)"
|
||||
cd "$REPO_ROOT"
|
||||
|
||||
# ── Load .env ────────────────────────────────────────────────────────────────
|
||||
|
||||
if [ -f "$REPO_ROOT/.env" ]; then
|
||||
set -a
|
||||
source "$REPO_ROOT/.env"
|
||||
set +a
|
||||
fi
|
||||
|
||||
_pick_python() {
|
||||
local candidate
|
||||
for candidate in python3 python py; do
|
||||
@@ -40,61 +48,6 @@ _pick_python() {
|
||||
return 1
|
||||
}
|
||||
|
||||
_load_dotenv_file() {
|
||||
local env_file=$1
|
||||
local python_bin
|
||||
|
||||
[ -f "$env_file" ] || return 0
|
||||
|
||||
if ! python_bin="$(_pick_python)"; then
|
||||
echo "Python is required to load $env_file safely."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
eval "$("$python_bin" - "$env_file" <<'PY'
|
||||
import re
|
||||
import shlex
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
env_path = Path(sys.argv[1])
|
||||
assign_re = re.compile(r"^(?:export\s+)?([A-Za-z_][A-Za-z0-9_]*)\s*=\s*(.*)$")
|
||||
|
||||
|
||||
def strip_unquoted_comment(value: str) -> str:
|
||||
for index, char in enumerate(value):
|
||||
if char == "#" and (index == 0 or value[index - 1].isspace()):
|
||||
return value[:index].rstrip()
|
||||
return value
|
||||
|
||||
|
||||
for raw_line in env_path.read_text(encoding="utf-8").splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#"):
|
||||
continue
|
||||
|
||||
match = assign_re.match(line)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
key, value = match.groups()
|
||||
value = value.strip()
|
||||
try:
|
||||
parsed = shlex.split(value, comments=True, posix=True)
|
||||
except ValueError:
|
||||
value = strip_unquoted_comment(value)
|
||||
else:
|
||||
value = parsed[0] if parsed else ""
|
||||
|
||||
print(f"export {key}={shlex.quote(value)}")
|
||||
PY
|
||||
)"
|
||||
}
|
||||
|
||||
# ── Load .env ────────────────────────────────────────────────────────────────
|
||||
|
||||
_load_dotenv_file "$REPO_ROOT/.env"
|
||||
|
||||
# ── Argument parsing ─────────────────────────────────────────────────────────
|
||||
|
||||
DEV_MODE=true
|
||||
@@ -226,10 +179,7 @@ _is_port_listening() {
|
||||
fi
|
||||
|
||||
if command -v netstat >/dev/null 2>&1; then
|
||||
if netstat -ltn 2>/dev/null | awk -v port="$port" '
|
||||
toupper($NF) == "LISTEN" && $4 ~ "(^|[.:])" port "$" { found = 1 }
|
||||
END { exit found ? 0 : 1 }
|
||||
'; then
|
||||
if netstat -ltn 2>/dev/null | awk '{print $4}' | grep -Eq "(^|[.:])${port}$"; then
|
||||
return 0
|
||||
fi
|
||||
fi
|
||||
@@ -237,21 +187,6 @@ _is_port_listening() {
|
||||
return 1
|
||||
}
|
||||
|
||||
_wait_for_port_free() {
|
||||
local port=$1
|
||||
local timeout=${2:-10}
|
||||
local elapsed=0
|
||||
|
||||
while _is_port_listening "$port"; do
|
||||
if [ "$elapsed" -ge "$timeout" ]; then
|
||||
echo " ⚠ Port $port is still in use after ${timeout}s"
|
||||
return 1
|
||||
fi
|
||||
sleep 1
|
||||
elapsed=$((elapsed + 1))
|
||||
done
|
||||
}
|
||||
|
||||
_is_repo_nginx_pid() {
|
||||
local pid=$1
|
||||
local command
|
||||
@@ -304,12 +239,9 @@ stop_all() {
|
||||
echo "Stopping all services..."
|
||||
_report_reclaimed_ports
|
||||
_kill_repo_processes "uvicorn app.gateway.app:app"
|
||||
_kill_repo_processes "pnpm .*run dev"
|
||||
_kill_repo_processes "next dev"
|
||||
_kill_repo_processes "next start"
|
||||
_kill_repo_processes "next-server"
|
||||
_kill_repo_processes "next/dist"
|
||||
_kill_repo_processes "turbopack"
|
||||
nginx -c "$REPO_ROOT/docker/nginx/nginx.local.conf" -p "$REPO_ROOT" -s quit 2>/dev/null || true
|
||||
sleep 1
|
||||
_kill_repo_nginx
|
||||
@@ -320,9 +252,6 @@ stop_all() {
|
||||
_kill_repo_port 8001
|
||||
_kill_repo_port 3000
|
||||
_kill_repo_port 2026
|
||||
_wait_for_port_free 8001 30 || true
|
||||
_wait_for_port_free 3000 30 || true
|
||||
_wait_for_port_free 2026 30 || true
|
||||
./scripts/cleanup-containers.sh deer-flow-sandbox 2>/dev/null || true
|
||||
echo "✓ All services stopped"
|
||||
}
|
||||
@@ -485,7 +414,6 @@ trap 'cleanup 143' TERM
|
||||
# In daemon mode, wraps with nohup. Waits for port to be ready.
|
||||
run_service() {
|
||||
local name="$1" cmd="$2" port="$3" timeout="$4"
|
||||
local service_pid
|
||||
|
||||
if _is_port_listening "$port"; then
|
||||
echo "✗ $name cannot start because port $port is already in use."
|
||||
@@ -495,14 +423,10 @@ run_service() {
|
||||
|
||||
echo "Starting $name..."
|
||||
if $DAEMON_MODE; then
|
||||
nohup sh -c "$cmd" < /dev/null > /dev/null 2>&1 &
|
||||
nohup sh -c "$cmd" > /dev/null 2>&1 &
|
||||
else
|
||||
sh -c "$cmd" &
|
||||
fi
|
||||
service_pid=$!
|
||||
if $DAEMON_MODE; then
|
||||
disown "$service_pid" 2>/dev/null || true
|
||||
fi
|
||||
|
||||
./scripts/wait-for-port.sh "$port" "$timeout" "$name" || {
|
||||
local logfile="logs/$(echo "$name" | tr '[:upper:]' '[:lower:]' | tr ' ' '-').log"
|
||||
@@ -510,12 +434,6 @@ run_service() {
|
||||
[ -f "$logfile" ] && tail -20 "$logfile"
|
||||
cleanup 1
|
||||
}
|
||||
if ! kill -0 "$service_pid" 2>/dev/null; then
|
||||
local logfile="logs/$(echo "$name" | tr '[:upper:]' '[:lower:]' | tr ' ' '-').log"
|
||||
echo "✗ $name process exited after port $port became available."
|
||||
[ -f "$logfile" ] && tail -20 "$logfile"
|
||||
cleanup 1
|
||||
fi
|
||||
echo "✓ $name started on localhost:$port"
|
||||
}
|
||||
|
||||
@@ -526,17 +444,17 @@ mkdir -p temp/client_body_temp temp/proxy_temp temp/fastcgi_temp temp/uwsgi_temp
|
||||
|
||||
# 1. Gateway API
|
||||
run_service "Gateway" \
|
||||
"cd backend && exec env PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 $GATEWAY_EXTRA_FLAGS > ../logs/gateway.log 2>&1" \
|
||||
"cd backend && PYTHONPATH=. uv run uvicorn app.gateway.app:app --host 0.0.0.0 --port 8001 $GATEWAY_EXTRA_FLAGS > ../logs/gateway.log 2>&1" \
|
||||
8001 30
|
||||
|
||||
# 2. Frontend
|
||||
run_service "Frontend" \
|
||||
"cd frontend && exec $FRONTEND_CMD > ../logs/frontend.log 2>&1" \
|
||||
"cd frontend && $FRONTEND_CMD > ../logs/frontend.log 2>&1" \
|
||||
3000 120
|
||||
|
||||
# 3. Nginx
|
||||
run_service "Nginx" \
|
||||
"exec nginx -g 'daemon off;' -c '$REPO_ROOT/docker/nginx/nginx.local.conf' -p '$REPO_ROOT' > logs/nginx.log 2>&1" \
|
||||
"nginx -g 'daemon off;' -c '$REPO_ROOT/docker/nginx/nginx.local.conf' -p '$REPO_ROOT' > logs/nginx.log 2>&1" \
|
||||
2026 10
|
||||
|
||||
# ── Ready ────────────────────────────────────────────────────────────────────
|
||||
|
||||
Reference in New Issue
Block a user