mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Let setup wizard enable IM channels
This commit is contained in:
@@ -1986,10 +1986,12 @@ class TestChannelManager:
|
|||||||
|
|
||||||
_run(go())
|
_run(go())
|
||||||
|
|
||||||
def test_same_topic_reuses_thread(self):
|
def test_same_topic_reuses_thread(self, monkeypatch):
|
||||||
"""Messages with the same topic_id should reuse the same DeerFlow thread."""
|
"""Messages with the same topic_id should reuse the same DeerFlow thread."""
|
||||||
from app.channels.manager import ChannelManager
|
from app.channels.manager import ChannelManager
|
||||||
|
|
||||||
|
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||||
|
|
||||||
async def go():
|
async def go():
|
||||||
bus = MessageBus()
|
bus = MessageBus()
|
||||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
|
|||||||
@@ -7,7 +7,9 @@ Run from repo root:
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
from wizard import ui as wizard_ui
|
||||||
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
|
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
|
||||||
|
from wizard.steps import channels as channels_step
|
||||||
from wizard.steps import llm as llm_step
|
from wizard.steps import llm as llm_step
|
||||||
from wizard.steps import search as search_step
|
from wizard.steps import search as search_step
|
||||||
from wizard.writer import (
|
from wizard.writer import (
|
||||||
@@ -327,6 +329,44 @@ class TestBuildMinimalConfig:
|
|||||||
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
|
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
|
||||||
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
|
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
|
||||||
|
|
||||||
|
def test_can_enable_selected_channel_connections(self):
|
||||||
|
content = build_minimal_config(
|
||||||
|
provider_use="langchain_openai:ChatOpenAI",
|
||||||
|
model_name="gpt-4o",
|
||||||
|
display_name="OpenAI",
|
||||||
|
api_key_field="api_key",
|
||||||
|
env_var="OPENAI_API_KEY",
|
||||||
|
channel_connection_providers=["feishu", "slack"],
|
||||||
|
)
|
||||||
|
|
||||||
|
data = yaml.safe_load(content)
|
||||||
|
channel_connections = data["channel_connections"]
|
||||||
|
|
||||||
|
assert channel_connections["enabled"] is True
|
||||||
|
assert channel_connections["feishu"]["enabled"] is True
|
||||||
|
assert channel_connections["slack"]["enabled"] is True
|
||||||
|
assert channel_connections["telegram"]["enabled"] is False
|
||||||
|
assert channel_connections["discord"]["enabled"] is False
|
||||||
|
assert channel_connections["dingtalk"]["enabled"] is False
|
||||||
|
assert channel_connections["wechat"]["enabled"] is False
|
||||||
|
assert channel_connections["wecom"]["enabled"] is False
|
||||||
|
|
||||||
|
def test_channel_connections_disabled_when_no_channels_selected(self):
|
||||||
|
content = build_minimal_config(
|
||||||
|
provider_use="langchain_openai:ChatOpenAI",
|
||||||
|
model_name="gpt-4o",
|
||||||
|
display_name="OpenAI",
|
||||||
|
api_key_field="api_key",
|
||||||
|
env_var="OPENAI_API_KEY",
|
||||||
|
channel_connection_providers=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
data = yaml.safe_load(content)
|
||||||
|
channel_connections = data["channel_connections"]
|
||||||
|
|
||||||
|
assert channel_connections["enabled"] is False
|
||||||
|
assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled")
|
||||||
|
|
||||||
|
|
||||||
class TestLLMStep:
|
class TestLLMStep:
|
||||||
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
|
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
|
||||||
@@ -384,6 +424,41 @@ class TestLLMStep:
|
|||||||
assert result.base_url == "https://gateway.example/v1"
|
assert result.base_url == "https://gateway.example/v1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestChannelsStep:
|
||||||
|
def test_returns_selected_channel_keys(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6])
|
||||||
|
|
||||||
|
result = channels_step.run_channels_step()
|
||||||
|
|
||||||
|
assert result.enabled_providers == ["telegram", "feishu", "wecom"]
|
||||||
|
|
||||||
|
def test_empty_selection_disables_channel_connections(self, monkeypatch):
|
||||||
|
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
|
||||||
|
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [])
|
||||||
|
|
||||||
|
result = channels_step.run_channels_step()
|
||||||
|
|
||||||
|
assert result.enabled_providers == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestWizardUi:
|
||||||
|
def test_multi_choice_blank_requires_input_without_default(self, monkeypatch):
|
||||||
|
answers = iter(["", "2"])
|
||||||
|
monkeypatch.setattr("builtins.input", lambda _prompt: next(answers))
|
||||||
|
|
||||||
|
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1]
|
||||||
|
|
||||||
|
def test_multi_choice_blank_accepts_empty_default(self, monkeypatch):
|
||||||
|
monkeypatch.setattr("builtins.input", lambda _prompt: "")
|
||||||
|
|
||||||
|
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == []
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# writer.py — env file helpers
|
# writer.py — env file helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
+10
-1
@@ -58,7 +58,7 @@ def main() -> int:
|
|||||||
return 0
|
return 0
|
||||||
print()
|
print()
|
||||||
|
|
||||||
total_steps = 4
|
total_steps = 5
|
||||||
|
|
||||||
from wizard.steps.llm import run_llm_step
|
from wizard.steps.llm import run_llm_step
|
||||||
|
|
||||||
@@ -76,6 +76,10 @@ def main() -> int:
|
|||||||
|
|
||||||
execution = run_execution_step(f"Step 3/{total_steps}")
|
execution = run_execution_step(f"Step 3/{total_steps}")
|
||||||
|
|
||||||
|
from wizard.steps.channels import run_channels_step
|
||||||
|
|
||||||
|
channels = run_channels_step(f"Step 4/{total_steps}")
|
||||||
|
|
||||||
print_header(f"Step {total_steps}/{total_steps} · Writing configuration")
|
print_header(f"Step {total_steps}/{total_steps} · Writing configuration")
|
||||||
|
|
||||||
write_config_yaml(
|
write_config_yaml(
|
||||||
@@ -97,6 +101,7 @@ def main() -> int:
|
|||||||
allow_host_bash=execution.allow_host_bash,
|
allow_host_bash=execution.allow_host_bash,
|
||||||
include_bash_tool=execution.include_bash_tool,
|
include_bash_tool=execution.include_bash_tool,
|
||||||
include_write_tools=execution.include_write_tools,
|
include_write_tools=execution.include_write_tools,
|
||||||
|
channel_connection_providers=channels.enabled_providers,
|
||||||
)
|
)
|
||||||
print_success(f"Config written to: {config_path.relative_to(project_root)}")
|
print_success(f"Config written to: {config_path.relative_to(project_root)}")
|
||||||
|
|
||||||
@@ -148,6 +153,10 @@ def main() -> int:
|
|||||||
print(f" {green('✓')} File write: enabled")
|
print(f" {green('✓')} File write: enabled")
|
||||||
else:
|
else:
|
||||||
print(f" {'—':>3} File write: disabled")
|
print(f" {'—':>3} File write: disabled")
|
||||||
|
if channels.enabled_providers:
|
||||||
|
print(f" {green('✓')} IM channels: {', '.join(channels.enabled_providers)}")
|
||||||
|
else:
|
||||||
|
print(f" {'—':>3} IM channels: disabled")
|
||||||
print()
|
print()
|
||||||
print("Next steps:")
|
print("Next steps:")
|
||||||
print(f" {cyan('make install')} # Install dependencies (first time only)")
|
print(f" {cyan('make install')} # Install dependencies (first time only)")
|
||||||
|
|||||||
@@ -0,0 +1,46 @@
|
|||||||
|
"""Step: browser-connectable IM channel enablement."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from wizard.ui import ask_multi_choice, print_header, print_info, print_success
|
||||||
|
|
||||||
|
|
||||||
|
CHANNEL_CONNECTION_OPTIONS: tuple[tuple[str, str, str], ...] = (
|
||||||
|
("telegram", "Telegram", "direct messages through your DeerFlow bot"),
|
||||||
|
("slack", "Slack", "workspace messages and mentions"),
|
||||||
|
("discord", "Discord", "server messages through your DeerFlow bot"),
|
||||||
|
("feishu", "Feishu / Lark", "messages through your DeerFlow app"),
|
||||||
|
("dingtalk", "DingTalk", "Stream Push messages through your DeerFlow bot"),
|
||||||
|
("wechat", "WeChat", "iLink messages through your DeerFlow bot"),
|
||||||
|
("wecom", "WeCom", "messages through your DeerFlow AI bot"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChannelConnectionsStepResult:
|
||||||
|
enabled_providers: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
def run_channels_step(step_label: str = "Step 4/5") -> ChannelConnectionsStepResult:
|
||||||
|
print_header(f"{step_label} · IM Channels (optional)")
|
||||||
|
print_info("Choose which IM channels should appear in the DeerFlow sidebar and Settings.")
|
||||||
|
print_info("Credentials can be entered later from the browser with Connect or Modify.")
|
||||||
|
print()
|
||||||
|
|
||||||
|
options = [f"{display_name} — {description}" for _, display_name, description in CHANNEL_CONNECTION_OPTIONS]
|
||||||
|
selected = ask_multi_choice(
|
||||||
|
"Enable channels (comma-separated numbers, 'all', or Enter for none)",
|
||||||
|
options,
|
||||||
|
default=[],
|
||||||
|
)
|
||||||
|
enabled_providers = [CHANNEL_CONNECTION_OPTIONS[idx][0] for idx in selected]
|
||||||
|
|
||||||
|
if enabled_providers:
|
||||||
|
display_names = [CHANNEL_CONNECTION_OPTIONS[idx][1] for idx in selected]
|
||||||
|
print_success(f"Enabled channels: {', '.join(display_names)}")
|
||||||
|
else:
|
||||||
|
print_info("No IM channels selected; channel connections will stay disabled.")
|
||||||
|
|
||||||
|
return ChannelConnectionsStepResult(enabled_providers=enabled_providers)
|
||||||
@@ -224,6 +224,49 @@ def ask_choice(prompt: str, options: list[str], default: int | None = None) -> i
|
|||||||
return _ask_choice_with_numbers(prompt, options, default=default)
|
return _ask_choice_with_numbers(prompt, options, default=default)
|
||||||
|
|
||||||
|
|
||||||
|
def ask_multi_choice(prompt: str, options: list[str], default: list[int] | None = None) -> list[int]:
|
||||||
|
"""Present a numbered multi-select menu and return 0-based indexes."""
|
||||||
|
has_default = default is not None
|
||||||
|
default_indexes = list(default or [])
|
||||||
|
for i, opt in enumerate(options, 1):
|
||||||
|
marker = f" {green('*')}" if has_default and i - 1 in default_indexes else " "
|
||||||
|
print(f"{marker} {i}. {opt}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
suffix = ""
|
||||||
|
if default_indexes:
|
||||||
|
suffix = f" [{','.join(str(idx + 1) for idx in default_indexes)}]"
|
||||||
|
elif has_default:
|
||||||
|
suffix = " [none]"
|
||||||
|
|
||||||
|
while True:
|
||||||
|
raw = input(f"{prompt}{suffix}: ").strip().lower()
|
||||||
|
if raw == "" and has_default:
|
||||||
|
return default_indexes
|
||||||
|
if raw in {"none", "no", "n", "skip"}:
|
||||||
|
return []
|
||||||
|
if raw == "all":
|
||||||
|
return list(range(len(options)))
|
||||||
|
|
||||||
|
parts = [part.strip() for part in raw.replace(" ", ",").split(",") if part.strip()]
|
||||||
|
selected: list[int] = []
|
||||||
|
valid = bool(parts)
|
||||||
|
for part in parts:
|
||||||
|
if not part.isdigit():
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
idx = int(part) - 1
|
||||||
|
if not 0 <= idx < len(options):
|
||||||
|
valid = False
|
||||||
|
break
|
||||||
|
if idx not in selected:
|
||||||
|
selected.append(idx)
|
||||||
|
if valid:
|
||||||
|
return selected
|
||||||
|
|
||||||
|
print(f" Enter comma-separated numbers between 1 and {len(options)}, 'all', or 'none'.")
|
||||||
|
|
||||||
|
|
||||||
def ask_text(prompt: str, default: str = "", required: bool = False) -> str:
|
def ask_text(prompt: str, default: str = "", required: bool = False) -> str:
|
||||||
"""Ask for a text value, returning default if the user presses Enter."""
|
"""Ask for a text value, returning default if the user presses Enter."""
|
||||||
suffix = f" [{default}]" if default else ""
|
suffix = f" [{default}]" if default else ""
|
||||||
|
|||||||
@@ -12,6 +12,16 @@ from typing import Any
|
|||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
|
CHANNEL_CONNECTION_PROVIDERS: tuple[str, ...] = (
|
||||||
|
"telegram",
|
||||||
|
"slack",
|
||||||
|
"discord",
|
||||||
|
"feishu",
|
||||||
|
"dingtalk",
|
||||||
|
"wechat",
|
||||||
|
"wecom",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _project_root() -> Path:
|
def _project_root() -> Path:
|
||||||
return Path(__file__).resolve().parents[2]
|
return Path(__file__).resolve().parents[2]
|
||||||
@@ -151,6 +161,18 @@ def _make_model_config_name(model_name: str) -> str:
|
|||||||
return base.replace(".", "-")
|
return base.replace(".", "-")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_channel_connections_config(enabled_providers: list[str]) -> dict[str, Any]:
|
||||||
|
selected = set(enabled_providers)
|
||||||
|
unknown = selected.difference(CHANNEL_CONNECTION_PROVIDERS)
|
||||||
|
if unknown:
|
||||||
|
raise ValueError(f"Unknown channel connection provider(s): {', '.join(sorted(unknown))}")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"enabled": bool(selected),
|
||||||
|
**{provider: {"enabled": provider in selected} for provider in CHANNEL_CONNECTION_PROVIDERS},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def build_minimal_config(
|
def build_minimal_config(
|
||||||
*,
|
*,
|
||||||
provider_use: str,
|
provider_use: str,
|
||||||
@@ -170,6 +192,7 @@ def build_minimal_config(
|
|||||||
allow_host_bash: bool = False,
|
allow_host_bash: bool = False,
|
||||||
include_bash_tool: bool = False,
|
include_bash_tool: bool = False,
|
||||||
include_write_tools: bool = True,
|
include_write_tools: bool = True,
|
||||||
|
channel_connection_providers: list[str] | None = None,
|
||||||
config_version: int = 5,
|
config_version: int = 5,
|
||||||
base_config: dict[str, Any] | None = None,
|
base_config: dict[str, Any] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -219,6 +242,8 @@ def build_minimal_config(
|
|||||||
else:
|
else:
|
||||||
sandbox_config.pop("allow_host_bash", None)
|
sandbox_config.pop("allow_host_bash", None)
|
||||||
data["sandbox"] = sandbox_config
|
data["sandbox"] = sandbox_config
|
||||||
|
if channel_connection_providers is not None:
|
||||||
|
data["channel_connections"] = _build_channel_connections_config(channel_connection_providers)
|
||||||
|
|
||||||
header = (
|
header = (
|
||||||
f"# DeerFlow Configuration\n"
|
f"# DeerFlow Configuration\n"
|
||||||
@@ -250,6 +275,7 @@ def write_config_yaml(
|
|||||||
allow_host_bash: bool = False,
|
allow_host_bash: bool = False,
|
||||||
include_bash_tool: bool = False,
|
include_bash_tool: bool = False,
|
||||||
include_write_tools: bool = True,
|
include_write_tools: bool = True,
|
||||||
|
channel_connection_providers: list[str] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Write (or overwrite) config.yaml with a minimal working configuration."""
|
"""Write (or overwrite) config.yaml with a minimal working configuration."""
|
||||||
# Read config_version from config.example.yaml if present
|
# Read config_version from config.example.yaml if present
|
||||||
@@ -284,6 +310,7 @@ def write_config_yaml(
|
|||||||
allow_host_bash=allow_host_bash,
|
allow_host_bash=allow_host_bash,
|
||||||
include_bash_tool=include_bash_tool,
|
include_bash_tool=include_bash_tool,
|
||||||
include_write_tools=include_write_tools,
|
include_write_tools=include_write_tools,
|
||||||
|
channel_connection_providers=channel_connection_providers,
|
||||||
config_version=config_version,
|
config_version=config_version,
|
||||||
base_config=example_defaults,
|
base_config=example_defaults,
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user