Let setup wizard enable IM channels

This commit is contained in:
taohe
2026-06-11 17:34:22 +08:00
parent a270e8b310
commit ddd1c5e42f
6 changed files with 204 additions and 2 deletions
+3 -1
View File
@@ -1986,10 +1986,12 @@ class TestChannelManager:
_run(go())
def test_same_topic_reuses_thread(self):
def test_same_topic_reuses_thread(self, monkeypatch):
"""Messages with the same topic_id should reuse the same DeerFlow thread."""
from app.channels.manager import ChannelManager
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
async def go():
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
+75
View File
@@ -7,7 +7,9 @@ Run from repo root:
from __future__ import annotations
import yaml
from wizard import ui as wizard_ui
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
from wizard.steps import channels as channels_step
from wizard.steps import llm as llm_step
from wizard.steps import search as search_step
from wizard.writer import (
@@ -327,6 +329,44 @@ class TestBuildMinimalConfig:
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
def test_can_enable_selected_channel_connections(self):
content = build_minimal_config(
provider_use="langchain_openai:ChatOpenAI",
model_name="gpt-4o",
display_name="OpenAI",
api_key_field="api_key",
env_var="OPENAI_API_KEY",
channel_connection_providers=["feishu", "slack"],
)
data = yaml.safe_load(content)
channel_connections = data["channel_connections"]
assert channel_connections["enabled"] is True
assert channel_connections["feishu"]["enabled"] is True
assert channel_connections["slack"]["enabled"] is True
assert channel_connections["telegram"]["enabled"] is False
assert channel_connections["discord"]["enabled"] is False
assert channel_connections["dingtalk"]["enabled"] is False
assert channel_connections["wechat"]["enabled"] is False
assert channel_connections["wecom"]["enabled"] is False
def test_channel_connections_disabled_when_no_channels_selected(self):
content = build_minimal_config(
provider_use="langchain_openai:ChatOpenAI",
model_name="gpt-4o",
display_name="OpenAI",
api_key_field="api_key",
env_var="OPENAI_API_KEY",
channel_connection_providers=[],
)
data = yaml.safe_load(content)
channel_connections = data["channel_connections"]
assert channel_connections["enabled"] is False
assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled")
class TestLLMStep:
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
@@ -384,6 +424,41 @@ class TestLLMStep:
assert result.base_url == "https://gateway.example/v1"
class TestChannelsStep:
def test_returns_selected_channel_keys(self, monkeypatch):
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6])
result = channels_step.run_channels_step()
assert result.enabled_providers == ["telegram", "feishu", "wecom"]
def test_empty_selection_disables_channel_connections(self, monkeypatch):
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [])
result = channels_step.run_channels_step()
assert result.enabled_providers == []
class TestWizardUi:
def test_multi_choice_blank_requires_input_without_default(self, monkeypatch):
answers = iter(["", "2"])
monkeypatch.setattr("builtins.input", lambda _prompt: next(answers))
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1]
def test_multi_choice_blank_accepts_empty_default(self, monkeypatch):
monkeypatch.setattr("builtins.input", lambda _prompt: "")
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == []
# ---------------------------------------------------------------------------
# writer.py — env file helpers
# ---------------------------------------------------------------------------
+10 -1
View File
@@ -58,7 +58,7 @@ def main() -> int:
return 0
print()
total_steps = 4
total_steps = 5
from wizard.steps.llm import run_llm_step
@@ -76,6 +76,10 @@ def main() -> int:
execution = run_execution_step(f"Step 3/{total_steps}")
from wizard.steps.channels import run_channels_step
channels = run_channels_step(f"Step 4/{total_steps}")
print_header(f"Step {total_steps}/{total_steps} · Writing configuration")
write_config_yaml(
@@ -97,6 +101,7 @@ def main() -> int:
allow_host_bash=execution.allow_host_bash,
include_bash_tool=execution.include_bash_tool,
include_write_tools=execution.include_write_tools,
channel_connection_providers=channels.enabled_providers,
)
print_success(f"Config written to: {config_path.relative_to(project_root)}")
@@ -148,6 +153,10 @@ def main() -> int:
print(f" {green('')} File write: enabled")
else:
print(f" {'':>3} File write: disabled")
if channels.enabled_providers:
print(f" {green('')} IM channels: {', '.join(channels.enabled_providers)}")
else:
print(f" {'':>3} IM channels: disabled")
print()
print("Next steps:")
print(f" {cyan('make install')} # Install dependencies (first time only)")
+46
View File
@@ -0,0 +1,46 @@
"""Step: browser-connectable IM channel enablement."""
from __future__ import annotations
from dataclasses import dataclass
from wizard.ui import ask_multi_choice, print_header, print_info, print_success
CHANNEL_CONNECTION_OPTIONS: tuple[tuple[str, str, str], ...] = (
("telegram", "Telegram", "direct messages through your DeerFlow bot"),
("slack", "Slack", "workspace messages and mentions"),
("discord", "Discord", "server messages through your DeerFlow bot"),
("feishu", "Feishu / Lark", "messages through your DeerFlow app"),
("dingtalk", "DingTalk", "Stream Push messages through your DeerFlow bot"),
("wechat", "WeChat", "iLink messages through your DeerFlow bot"),
("wecom", "WeCom", "messages through your DeerFlow AI bot"),
)
@dataclass
class ChannelConnectionsStepResult:
enabled_providers: list[str]
def run_channels_step(step_label: str = "Step 4/5") -> ChannelConnectionsStepResult:
print_header(f"{step_label} · IM Channels (optional)")
print_info("Choose which IM channels should appear in the DeerFlow sidebar and Settings.")
print_info("Credentials can be entered later from the browser with Connect or Modify.")
print()
options = [f"{display_name}{description}" for _, display_name, description in CHANNEL_CONNECTION_OPTIONS]
selected = ask_multi_choice(
"Enable channels (comma-separated numbers, 'all', or Enter for none)",
options,
default=[],
)
enabled_providers = [CHANNEL_CONNECTION_OPTIONS[idx][0] for idx in selected]
if enabled_providers:
display_names = [CHANNEL_CONNECTION_OPTIONS[idx][1] for idx in selected]
print_success(f"Enabled channels: {', '.join(display_names)}")
else:
print_info("No IM channels selected; channel connections will stay disabled.")
return ChannelConnectionsStepResult(enabled_providers=enabled_providers)
+43
View File
@@ -224,6 +224,49 @@ def ask_choice(prompt: str, options: list[str], default: int | None = None) -> i
return _ask_choice_with_numbers(prompt, options, default=default)
def ask_multi_choice(prompt: str, options: list[str], default: list[int] | None = None) -> list[int]:
"""Present a numbered multi-select menu and return 0-based indexes."""
has_default = default is not None
default_indexes = list(default or [])
for i, opt in enumerate(options, 1):
marker = f" {green('*')}" if has_default and i - 1 in default_indexes else " "
print(f"{marker} {i}. {opt}")
print()
suffix = ""
if default_indexes:
suffix = f" [{','.join(str(idx + 1) for idx in default_indexes)}]"
elif has_default:
suffix = " [none]"
while True:
raw = input(f"{prompt}{suffix}: ").strip().lower()
if raw == "" and has_default:
return default_indexes
if raw in {"none", "no", "n", "skip"}:
return []
if raw == "all":
return list(range(len(options)))
parts = [part.strip() for part in raw.replace(" ", ",").split(",") if part.strip()]
selected: list[int] = []
valid = bool(parts)
for part in parts:
if not part.isdigit():
valid = False
break
idx = int(part) - 1
if not 0 <= idx < len(options):
valid = False
break
if idx not in selected:
selected.append(idx)
if valid:
return selected
print(f" Enter comma-separated numbers between 1 and {len(options)}, 'all', or 'none'.")
def ask_text(prompt: str, default: str = "", required: bool = False) -> str:
"""Ask for a text value, returning default if the user presses Enter."""
suffix = f" [{default}]" if default else ""
+27
View File
@@ -12,6 +12,16 @@ from typing import Any
import yaml
CHANNEL_CONNECTION_PROVIDERS: tuple[str, ...] = (
"telegram",
"slack",
"discord",
"feishu",
"dingtalk",
"wechat",
"wecom",
)
def _project_root() -> Path:
return Path(__file__).resolve().parents[2]
@@ -151,6 +161,18 @@ def _make_model_config_name(model_name: str) -> str:
return base.replace(".", "-")
def _build_channel_connections_config(enabled_providers: list[str]) -> dict[str, Any]:
selected = set(enabled_providers)
unknown = selected.difference(CHANNEL_CONNECTION_PROVIDERS)
if unknown:
raise ValueError(f"Unknown channel connection provider(s): {', '.join(sorted(unknown))}")
return {
"enabled": bool(selected),
**{provider: {"enabled": provider in selected} for provider in CHANNEL_CONNECTION_PROVIDERS},
}
def build_minimal_config(
*,
provider_use: str,
@@ -170,6 +192,7 @@ def build_minimal_config(
allow_host_bash: bool = False,
include_bash_tool: bool = False,
include_write_tools: bool = True,
channel_connection_providers: list[str] | None = None,
config_version: int = 5,
base_config: dict[str, Any] | None = None,
) -> str:
@@ -219,6 +242,8 @@ def build_minimal_config(
else:
sandbox_config.pop("allow_host_bash", None)
data["sandbox"] = sandbox_config
if channel_connection_providers is not None:
data["channel_connections"] = _build_channel_connections_config(channel_connection_providers)
header = (
f"# DeerFlow Configuration\n"
@@ -250,6 +275,7 @@ def write_config_yaml(
allow_host_bash: bool = False,
include_bash_tool: bool = False,
include_write_tools: bool = True,
channel_connection_providers: list[str] | None = None,
) -> None:
"""Write (or overwrite) config.yaml with a minimal working configuration."""
# Read config_version from config.example.yaml if present
@@ -284,6 +310,7 @@ def write_config_yaml(
allow_host_bash=allow_host_bash,
include_bash_tool=include_bash_tool,
include_write_tools=include_write_tools,
channel_connection_providers=channel_connection_providers,
config_version=config_version,
base_config=example_defaults,
)