diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index c72c46fad..85aa6f7da 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -1986,10 +1986,12 @@ class TestChannelManager: _run(go()) - def test_same_topic_reuses_thread(self): + def test_same_topic_reuses_thread(self, monkeypatch): """Messages with the same topic_id should reuse the same DeerFlow thread.""" from app.channels.manager import ChannelManager + monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False) + async def go(): bus = MessageBus() store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json") diff --git a/backend/tests/test_setup_wizard.py b/backend/tests/test_setup_wizard.py index 9eecb2eae..5f8be4ae0 100644 --- a/backend/tests/test_setup_wizard.py +++ b/backend/tests/test_setup_wizard.py @@ -7,7 +7,9 @@ Run from repo root: from __future__ import annotations import yaml +from wizard import ui as wizard_ui from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider +from wizard.steps import channels as channels_step from wizard.steps import llm as llm_step from wizard.steps import search as search_step from wizard.writer import ( @@ -327,6 +329,44 @@ class TestBuildMinimalConfig: assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled" assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled" + def test_can_enable_selected_channel_connections(self): + content = build_minimal_config( + provider_use="langchain_openai:ChatOpenAI", + model_name="gpt-4o", + display_name="OpenAI", + api_key_field="api_key", + env_var="OPENAI_API_KEY", + channel_connection_providers=["feishu", "slack"], + ) + + data = yaml.safe_load(content) + channel_connections = data["channel_connections"] + + assert channel_connections["enabled"] is True + assert channel_connections["feishu"]["enabled"] is True + assert channel_connections["slack"]["enabled"] is True + assert channel_connections["telegram"]["enabled"] is False + assert channel_connections["discord"]["enabled"] is False + assert channel_connections["dingtalk"]["enabled"] is False + assert channel_connections["wechat"]["enabled"] is False + assert channel_connections["wecom"]["enabled"] is False + + def test_channel_connections_disabled_when_no_channels_selected(self): + content = build_minimal_config( + provider_use="langchain_openai:ChatOpenAI", + model_name="gpt-4o", + display_name="OpenAI", + api_key_field="api_key", + env_var="OPENAI_API_KEY", + channel_connection_providers=[], + ) + + data = yaml.safe_load(content) + channel_connections = data["channel_connections"] + + assert channel_connections["enabled"] is False + assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled") + class TestLLMStep: def test_model_selection_defaults_to_provider_default_model(self, monkeypatch): @@ -384,6 +424,41 @@ class TestLLMStep: assert result.base_url == "https://gateway.example/v1" +class TestChannelsStep: + def test_returns_selected_channel_keys(self, monkeypatch): + monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6]) + + result = channels_step.run_channels_step() + + assert result.enabled_providers == ["telegram", "feishu", "wecom"] + + def test_empty_selection_disables_channel_connections(self, monkeypatch): + monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None) + monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: []) + + result = channels_step.run_channels_step() + + assert result.enabled_providers == [] + + +class TestWizardUi: + def test_multi_choice_blank_requires_input_without_default(self, monkeypatch): + answers = iter(["", "2"]) + monkeypatch.setattr("builtins.input", lambda _prompt: next(answers)) + + assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1] + + def test_multi_choice_blank_accepts_empty_default(self, monkeypatch): + monkeypatch.setattr("builtins.input", lambda _prompt: "") + + assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == [] + + # --------------------------------------------------------------------------- # writer.py — env file helpers # --------------------------------------------------------------------------- diff --git a/scripts/setup_wizard.py b/scripts/setup_wizard.py index c3a7baf55..2b86f5f6b 100644 --- a/scripts/setup_wizard.py +++ b/scripts/setup_wizard.py @@ -58,7 +58,7 @@ def main() -> int: return 0 print() - total_steps = 4 + total_steps = 5 from wizard.steps.llm import run_llm_step @@ -76,6 +76,10 @@ def main() -> int: execution = run_execution_step(f"Step 3/{total_steps}") + from wizard.steps.channels import run_channels_step + + channels = run_channels_step(f"Step 4/{total_steps}") + print_header(f"Step {total_steps}/{total_steps} · Writing configuration") write_config_yaml( @@ -97,6 +101,7 @@ def main() -> int: allow_host_bash=execution.allow_host_bash, include_bash_tool=execution.include_bash_tool, include_write_tools=execution.include_write_tools, + channel_connection_providers=channels.enabled_providers, ) print_success(f"Config written to: {config_path.relative_to(project_root)}") @@ -148,6 +153,10 @@ def main() -> int: print(f" {green('✓')} File write: enabled") else: print(f" {'—':>3} File write: disabled") + if channels.enabled_providers: + print(f" {green('✓')} IM channels: {', '.join(channels.enabled_providers)}") + else: + print(f" {'—':>3} IM channels: disabled") print() print("Next steps:") print(f" {cyan('make install')} # Install dependencies (first time only)") diff --git a/scripts/wizard/steps/channels.py b/scripts/wizard/steps/channels.py new file mode 100644 index 000000000..302bd8fe8 --- /dev/null +++ b/scripts/wizard/steps/channels.py @@ -0,0 +1,46 @@ +"""Step: browser-connectable IM channel enablement.""" + +from __future__ import annotations + +from dataclasses import dataclass + +from wizard.ui import ask_multi_choice, print_header, print_info, print_success + + +CHANNEL_CONNECTION_OPTIONS: tuple[tuple[str, str, str], ...] = ( + ("telegram", "Telegram", "direct messages through your DeerFlow bot"), + ("slack", "Slack", "workspace messages and mentions"), + ("discord", "Discord", "server messages through your DeerFlow bot"), + ("feishu", "Feishu / Lark", "messages through your DeerFlow app"), + ("dingtalk", "DingTalk", "Stream Push messages through your DeerFlow bot"), + ("wechat", "WeChat", "iLink messages through your DeerFlow bot"), + ("wecom", "WeCom", "messages through your DeerFlow AI bot"), +) + + +@dataclass +class ChannelConnectionsStepResult: + enabled_providers: list[str] + + +def run_channels_step(step_label: str = "Step 4/5") -> ChannelConnectionsStepResult: + print_header(f"{step_label} · IM Channels (optional)") + print_info("Choose which IM channels should appear in the DeerFlow sidebar and Settings.") + print_info("Credentials can be entered later from the browser with Connect or Modify.") + print() + + options = [f"{display_name} — {description}" for _, display_name, description in CHANNEL_CONNECTION_OPTIONS] + selected = ask_multi_choice( + "Enable channels (comma-separated numbers, 'all', or Enter for none)", + options, + default=[], + ) + enabled_providers = [CHANNEL_CONNECTION_OPTIONS[idx][0] for idx in selected] + + if enabled_providers: + display_names = [CHANNEL_CONNECTION_OPTIONS[idx][1] for idx in selected] + print_success(f"Enabled channels: {', '.join(display_names)}") + else: + print_info("No IM channels selected; channel connections will stay disabled.") + + return ChannelConnectionsStepResult(enabled_providers=enabled_providers) diff --git a/scripts/wizard/ui.py b/scripts/wizard/ui.py index 289652f25..82ffe3033 100644 --- a/scripts/wizard/ui.py +++ b/scripts/wizard/ui.py @@ -224,6 +224,49 @@ def ask_choice(prompt: str, options: list[str], default: int | None = None) -> i return _ask_choice_with_numbers(prompt, options, default=default) +def ask_multi_choice(prompt: str, options: list[str], default: list[int] | None = None) -> list[int]: + """Present a numbered multi-select menu and return 0-based indexes.""" + has_default = default is not None + default_indexes = list(default or []) + for i, opt in enumerate(options, 1): + marker = f" {green('*')}" if has_default and i - 1 in default_indexes else " " + print(f"{marker} {i}. {opt}") + print() + + suffix = "" + if default_indexes: + suffix = f" [{','.join(str(idx + 1) for idx in default_indexes)}]" + elif has_default: + suffix = " [none]" + + while True: + raw = input(f"{prompt}{suffix}: ").strip().lower() + if raw == "" and has_default: + return default_indexes + if raw in {"none", "no", "n", "skip"}: + return [] + if raw == "all": + return list(range(len(options))) + + parts = [part.strip() for part in raw.replace(" ", ",").split(",") if part.strip()] + selected: list[int] = [] + valid = bool(parts) + for part in parts: + if not part.isdigit(): + valid = False + break + idx = int(part) - 1 + if not 0 <= idx < len(options): + valid = False + break + if idx not in selected: + selected.append(idx) + if valid: + return selected + + print(f" Enter comma-separated numbers between 1 and {len(options)}, 'all', or 'none'.") + + def ask_text(prompt: str, default: str = "", required: bool = False) -> str: """Ask for a text value, returning default if the user presses Enter.""" suffix = f" [{default}]" if default else "" diff --git a/scripts/wizard/writer.py b/scripts/wizard/writer.py index e2324340e..911f60cb8 100644 --- a/scripts/wizard/writer.py +++ b/scripts/wizard/writer.py @@ -12,6 +12,16 @@ from typing import Any import yaml +CHANNEL_CONNECTION_PROVIDERS: tuple[str, ...] = ( + "telegram", + "slack", + "discord", + "feishu", + "dingtalk", + "wechat", + "wecom", +) + def _project_root() -> Path: return Path(__file__).resolve().parents[2] @@ -151,6 +161,18 @@ def _make_model_config_name(model_name: str) -> str: return base.replace(".", "-") +def _build_channel_connections_config(enabled_providers: list[str]) -> dict[str, Any]: + selected = set(enabled_providers) + unknown = selected.difference(CHANNEL_CONNECTION_PROVIDERS) + if unknown: + raise ValueError(f"Unknown channel connection provider(s): {', '.join(sorted(unknown))}") + + return { + "enabled": bool(selected), + **{provider: {"enabled": provider in selected} for provider in CHANNEL_CONNECTION_PROVIDERS}, + } + + def build_minimal_config( *, provider_use: str, @@ -170,6 +192,7 @@ def build_minimal_config( allow_host_bash: bool = False, include_bash_tool: bool = False, include_write_tools: bool = True, + channel_connection_providers: list[str] | None = None, config_version: int = 5, base_config: dict[str, Any] | None = None, ) -> str: @@ -219,6 +242,8 @@ def build_minimal_config( else: sandbox_config.pop("allow_host_bash", None) data["sandbox"] = sandbox_config + if channel_connection_providers is not None: + data["channel_connections"] = _build_channel_connections_config(channel_connection_providers) header = ( f"# DeerFlow Configuration\n" @@ -250,6 +275,7 @@ def write_config_yaml( allow_host_bash: bool = False, include_bash_tool: bool = False, include_write_tools: bool = True, + channel_connection_providers: list[str] | None = None, ) -> None: """Write (or overwrite) config.yaml with a minimal working configuration.""" # Read config_version from config.example.yaml if present @@ -284,6 +310,7 @@ def write_config_yaml( allow_host_bash=allow_host_bash, include_bash_tool=include_bash_tool, include_write_tools=include_write_tools, + channel_connection_providers=channel_connection_providers, config_version=config_version, base_config=example_defaults, )