Compare commits

...

6 Commits

Author SHA1 Message Date
laundry ba6198f3ec test: fix unit test error
Change-Id: I3dd7a6179132e5497a30ada443d88de0c47af3d4
2025-05-20 11:59:00 +08:00
laundry 28a01dfe0e test: fix unit test error
Change-Id: If4c4cd10673e76a30945674c7cda198aeabf28d0
2025-05-20 11:58:39 +08:00
laundry 3b1db26507 test: fix test error
Change-Id: I3997dc53a2cfaa35501a1fbda5902ee15528124e
2025-05-19 15:47:08 +08:00
laundry e927b556d6 test: add background node unit test
Change-Id: I9aabcf02ff04fda40c56f3ea22abe6b8f93bf9b6
2025-05-19 15:36:24 +08:00
laundry c2b8dd8e6a test: add background node unit test
Change-Id: Ia99f5a1687464387dcb01bbee04deaa371c6e490
2025-05-19 15:33:42 +08:00
DanielWalnut 8bbcdbe4de feat: config max_search_results for search engine (#192)
* feat: implement UI

* feat: config max_search_results for search engine via api

---------

Co-authored-by: Henry Li <henry1943@163.com>
2025-05-18 13:23:52 +08:00
16 changed files with 233 additions and 88 deletions
+2 -2
View File
@@ -1,6 +1,6 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .agents import research_agent, coder_agent
from .agents import create_agent
__all__ = ["research_agent", "coder_agent"]
__all__ = ["create_agent"]
-13
View File
@@ -4,12 +4,6 @@
from langgraph.prebuilt import create_react_agent
from src.prompts import apply_prompt_template
from src.tools import (
crawl_tool,
python_repl_tool,
web_search_tool,
)
from src.llms.llm import get_llm_by_type
from src.config.agents import AGENT_LLM_MAP
@@ -23,10 +17,3 @@ def create_agent(agent_name: str, agent_type: str, tools: list, prompt_template:
tools=tools,
prompt=lambda state: apply_prompt_template(prompt_template, state),
)
# Create agents using the factory function
research_agent = create_agent(
"researcher", "researcher", [web_search_tool, crawl_tool], "researcher"
)
coder_agent = create_agent("coder", "coder", [python_repl_tool], "coder")
+1 -2
View File
@@ -1,7 +1,7 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from .tools import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
from .tools import SELECTED_SEARCH_ENGINE, SearchEngine
from .loader import load_yaml_config
from .questions import BUILT_IN_QUESTIONS, BUILT_IN_QUESTIONS_ZH_CN
@@ -42,7 +42,6 @@ __all__ = [
# Other configurations
"TEAM_MEMBERS",
"TEAM_MEMBER_CONFIGRATIONS",
"SEARCH_MAX_RESULTS",
"SELECTED_SEARCH_ENGINE",
"SearchEngine",
"BUILT_IN_QUESTIONS",
+1
View File
@@ -14,6 +14,7 @@ class Configuration:
max_plan_iterations: int = 1 # Maximum number of plan iterations
max_step_num: int = 3 # Maximum number of steps in a plan
max_search_results: int = 3 # Maximum number of search results
mcp_settings: dict = None # MCP settings, including dynamic loaded tools
@classmethod
-1
View File
@@ -17,4 +17,3 @@ class SearchEngine(enum.Enum):
# Tool configuration
SELECTED_SEARCH_ENGINE = os.getenv("SEARCH_API", SearchEngine.TAVILY.value)
SEARCH_MAX_RESULTS = 3
+18 -16
View File
@@ -12,12 +12,11 @@ from langchain_core.tools import tool
from langgraph.types import Command, interrupt
from langchain_mcp_adapters.client import MultiServerMCPClient
from src.agents.agents import coder_agent, research_agent, create_agent
from src.agents import create_agent
from src.tools.search import LoggedTavilySearch
from src.tools import (
crawl_tool,
web_search_tool,
get_web_search_tool,
python_repl_tool,
)
@@ -29,7 +28,7 @@ from src.prompts.template import apply_prompt_template
from src.utils.json_utils import repair_json_output
from .types import State
from ..config import SEARCH_MAX_RESULTS, SELECTED_SEARCH_ENGINE, SearchEngine
from ..config import SELECTED_SEARCH_ENGINE, SearchEngine
logger = logging.getLogger(__name__)
@@ -45,13 +44,16 @@ def handoff_to_planner(
return
def background_investigation_node(state: State) -> Command[Literal["planner"]]:
def background_investigation_node(
state: State, config: RunnableConfig
) -> Command[Literal["planner"]]:
logger.info("background investigation node is running.")
configurable = Configuration.from_runnable_config(config)
query = state["messages"][-1].content
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY:
searched_content = LoggedTavilySearch(max_results=SEARCH_MAX_RESULTS).invoke(
{"query": query}
)
searched_content = LoggedTavilySearch(
max_results=configurable.max_search_results
).invoke({"query": query})
background_investigation_results = None
if isinstance(searched_content, list):
background_investigation_results = [
@@ -63,7 +65,9 @@ def background_investigation_node(state: State) -> Command[Literal["planner"]]:
f"Tavily search returned malformed response: {searched_content}"
)
else:
background_investigation_results = web_search_tool.invoke(query)
background_investigation_results = get_web_search_tool(
configurable.max_search_results
).invoke(query)
return Command(
update={
"background_investigation_results": json.dumps(
@@ -403,7 +407,6 @@ async def _setup_and_execute_agent_step(
state: State,
config: RunnableConfig,
agent_type: str,
default_agent,
default_tools: list,
) -> Command[Literal["research_team"]]:
"""Helper function to set up an agent with appropriate tools and execute a step.
@@ -417,7 +420,6 @@ async def _setup_and_execute_agent_step(
state: The current state
config: The runnable config
agent_type: The type of agent ("researcher" or "coder")
default_agent: The default agent to use if no MCP servers are configured
default_tools: The default tools to add to the agent
Returns:
@@ -455,8 +457,9 @@ async def _setup_and_execute_agent_step(
agent = create_agent(agent_type, agent_type, loaded_tools, agent_type)
return await _execute_agent_step(state, agent, agent_type)
else:
# Use default agent if no MCP servers are configured
return await _execute_agent_step(state, default_agent, agent_type)
# Use default tools if no MCP servers are configured
agent = create_agent(agent_type, agent_type, default_tools, agent_type)
return await _execute_agent_step(state, agent, agent_type)
async def researcher_node(
@@ -464,12 +467,12 @@ async def researcher_node(
) -> Command[Literal["research_team"]]:
"""Researcher node that do research"""
logger.info("Researcher node is researching.")
configurable = Configuration.from_runnable_config(config)
return await _setup_and_execute_agent_step(
state,
config,
"researcher",
research_agent,
[web_search_tool, crawl_tool],
[get_web_search_tool(configurable.max_search_results), crawl_tool],
)
@@ -482,6 +485,5 @@ async def coder_node(
state,
config,
"coder",
coder_agent,
[python_repl_tool],
)
+2 -3
View File
@@ -44,13 +44,12 @@ def get_llm_by_type(
return llm
# Initialize LLMs for different purposes - now these will be cached
basic_llm = get_llm_by_type("basic")
# In the future, we will use reasoning_llm and vl_llm for different purposes
# reasoning_llm = get_llm_by_type("reasoning")
# vl_llm = get_llm_by_type("vision")
if __name__ == "__main__":
# Initialize LLMs for different purposes - now these will be cached
basic_llm = get_llm_by_type("basic")
print(basic_llm.invoke("Hello"))
+3
View File
@@ -61,6 +61,7 @@ async def chat_stream(request: ChatRequest):
thread_id,
request.max_plan_iterations,
request.max_step_num,
request.max_search_results,
request.auto_accepted_plan,
request.interrupt_feedback,
request.mcp_settings,
@@ -75,6 +76,7 @@ async def _astream_workflow_generator(
thread_id: str,
max_plan_iterations: int,
max_step_num: int,
max_search_results: int,
auto_accepted_plan: bool,
interrupt_feedback: str,
mcp_settings: dict,
@@ -101,6 +103,7 @@ async def _astream_workflow_generator(
"thread_id": thread_id,
"max_plan_iterations": max_plan_iterations,
"max_step_num": max_step_num,
"max_search_results": max_search_results,
"mcp_settings": mcp_settings,
},
stream_mode=["messages", "updates"],
+3
View File
@@ -38,6 +38,9 @@ class ChatRequest(BaseModel):
max_step_num: Optional[int] = Field(
3, description="The maximum number of steps in a plan"
)
max_search_results: Optional[int] = Field(
3, description="The maximum number of search results"
)
auto_accepted_plan: Optional[bool] = Field(
False, description="Whether to automatically accept the plan"
)
+2 -18
View File
@@ -5,28 +5,12 @@ import os
from .crawl import crawl_tool
from .python_repl import python_repl_tool
from .search import (
tavily_search_tool,
duckduckgo_search_tool,
brave_search_tool,
arxiv_search_tool,
)
from .search import get_web_search_tool
from .tts import VolcengineTTS
from src.config import SELECTED_SEARCH_ENGINE, SearchEngine
# Map search engine names to their respective tools
search_tool_mappings = {
SearchEngine.TAVILY.value: tavily_search_tool,
SearchEngine.DUCKDUCKGO.value: duckduckgo_search_tool,
SearchEngine.BRAVE_SEARCH.value: brave_search_tool,
SearchEngine.ARXIV.value: arxiv_search_tool,
}
web_search_tool = search_tool_mappings.get(SELECTED_SEARCH_ENGINE, tavily_search_tool)
__all__ = [
"crawl_tool",
"web_search_tool",
"python_repl_tool",
"get_web_search_tool",
"VolcengineTTS",
]
+37 -33
View File
@@ -9,7 +9,7 @@ from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
from langchain_community.tools.arxiv import ArxivQueryRun
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
from src.config import SEARCH_MAX_RESULTS, SearchEngine
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
@@ -18,44 +18,48 @@ from src.tools.decorators import create_logged_tool
logger = logging.getLogger(__name__)
# Create logged versions of the search tools
LoggedTavilySearch = create_logged_tool(TavilySearchResultsWithImages)
if os.getenv("SEARCH_API", "") == SearchEngine.TAVILY.value:
tavily_search_tool = LoggedTavilySearch(
name="web_search",
max_results=SEARCH_MAX_RESULTS,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
else:
tavily_search_tool = None
LoggedDuckDuckGoSearch = create_logged_tool(DuckDuckGoSearchResults)
duckduckgo_search_tool = LoggedDuckDuckGoSearch(
name="web_search", max_results=SEARCH_MAX_RESULTS
)
LoggedBraveSearch = create_logged_tool(BraveSearch)
brave_search_tool = LoggedBraveSearch(
name="web_search",
search_wrapper=BraveSearchWrapper(
api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
search_kwargs={"count": SEARCH_MAX_RESULTS},
),
)
LoggedArxivSearch = create_logged_tool(ArxivQueryRun)
arxiv_search_tool = LoggedArxivSearch(
name="web_search",
api_wrapper=ArxivAPIWrapper(
top_k_results=SEARCH_MAX_RESULTS,
load_max_docs=SEARCH_MAX_RESULTS,
load_all_available_meta=True,
),
)
# Get the selected search tool
def get_web_search_tool(max_search_results: int):
if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
return LoggedTavilySearch(
name="web_search",
max_results=max_search_results,
include_raw_content=True,
include_images=True,
include_image_descriptions=True,
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
return LoggedDuckDuckGoSearch(name="web_search", max_results=max_search_results)
elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
return LoggedBraveSearch(
name="web_search",
search_wrapper=BraveSearchWrapper(
api_key=os.getenv("BRAVE_SEARCH_API_KEY", ""),
search_kwargs={"count": max_search_results},
),
)
elif SELECTED_SEARCH_ENGINE == SearchEngine.ARXIV.value:
return LoggedArxivSearch(
name="web_search",
api_wrapper=ArxivAPIWrapper(
top_k_results=max_search_results,
load_max_docs=max_search_results,
load_all_available_meta=True,
),
)
else:
raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")
if __name__ == "__main__":
results = LoggedDuckDuckGoSearch(
name="web_search", max_results=SEARCH_MAX_RESULTS, output_format="list"
name="web_search", max_results=3, output_format="list"
).invoke("cute panda")
print(json.dumps(results, indent=2, ensure_ascii=False))
+130
View File
@@ -0,0 +1,130 @@
import json
import pytest
from unittest.mock import patch, MagicMock
# 在这里 mock 掉 get_llm_by_type,避免 ValueError
with patch("src.llms.llm.get_llm_by_type", return_value=MagicMock()):
from langgraph.types import Command
from src.graph.nodes import background_investigation_node
from src.config import SearchEngine
from langchain_core.messages import HumanMessage
# Mock data
MOCK_SEARCH_RESULTS = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
@pytest.fixture
def mock_state():
return {
"messages": [HumanMessage(content="test query")],
"background_investigation_results": None,
}
@pytest.fixture
def mock_configurable():
mock = MagicMock()
mock.max_search_results = 5
return mock
@pytest.fixture
def mock_config():
# 你可以根据实际需要返回一个 MagicMock 或 dict
return MagicMock()
@pytest.fixture
def patch_config_from_runnable_config(mock_configurable):
with patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
):
yield
@pytest.fixture
def mock_tavily_search():
with patch("src.graph.nodes.LoggedTavilySearch") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.fixture
def mock_web_search_tool():
with patch("src.graph.nodes.get_web_search_tool") as mock:
instance = mock.return_value
instance.invoke.return_value = [
{"title": "Test Title 1", "content": "Test Content 1"},
{"title": "Test Title 2", "content": "Test Content 2"},
]
yield mock
@pytest.mark.parametrize("search_engine", [SearchEngine.TAVILY, "other"])
def test_background_investigation_node_tavily(
mock_state,
mock_tavily_search,
mock_web_search_tool,
search_engine,
patch_config_from_runnable_config,
mock_config,
):
"""Test background_investigation_node with Tavily search engine"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", search_engine):
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, Command)
assert result.goto == "planner"
# Verify the update contains background_investigation_results
update = result.update
assert "background_investigation_results" in update
# Parse and verify the JSON content
results = json.loads(update["background_investigation_results"])
assert isinstance(results, list)
if search_engine == SearchEngine.TAVILY:
mock_tavily_search.return_value.invoke.assert_called_once_with(
{"query": "test query"}
)
assert len(results) == 2
assert results[0]["title"] == "Test Title 1"
assert results[0]["content"] == "Test Content 1"
else:
mock_web_search_tool.return_value.invoke.assert_called_once_with(
"test query"
)
assert len(results) == 2
def test_background_investigation_node_malformed_response(
mock_state, mock_tavily_search, patch_config_from_runnable_config, mock_config
):
"""Test background_investigation_node with malformed Tavily response"""
with patch("src.graph.nodes.SELECTED_SEARCH_ENGINE", SearchEngine.TAVILY):
# Mock a malformed response
mock_tavily_search.return_value.invoke.return_value = "invalid response"
result = background_investigation_node(mock_state, mock_config)
# Verify the result structure
assert isinstance(result, Command)
assert result.goto == "planner"
# Verify the update contains background_investigation_results
update = result.update
assert "background_investigation_results" in update
# Parse and verify the JSON content
results = json.loads(update["background_investigation_results"])
assert results is None
+27
View File
@@ -32,6 +32,9 @@ const generalFormSchema = z.object({
maxStepNum: z.number().min(1, {
message: "Max step number must be at least 1.",
}),
maxSearchResults: z.number().min(1, {
message: "Max search results must be at least 1.",
}),
});
export const GeneralTab: Tab = ({
@@ -143,6 +146,30 @@ export const GeneralTab: Tab = ({
</FormItem>
)}
/>
<FormField
control={form.control}
name="maxSearchResults"
render={({ field }) => (
<FormItem>
<FormLabel>Max search results</FormLabel>
<FormControl>
<Input
className="w-60"
type="number"
defaultValue={field.value}
min={1}
onChange={(event) =>
field.onChange(parseInt(event.target.value || "0"))
}
/>
</FormControl>
<FormDescription>
By default, each search step has 3 results.
</FormDescription>
<FormMessage />
</FormItem>
)}
/>
</form>
</Form>
</main>
+4
View File
@@ -18,6 +18,7 @@ export async function* chatStream(
auto_accepted_plan: boolean;
max_plan_iterations: number;
max_step_num: number;
max_search_results?: number;
interrupt_feedback?: string;
enable_background_investigation: boolean;
mcp_settings?: {
@@ -61,12 +62,14 @@ async function* chatReplayStream(
auto_accepted_plan: boolean;
max_plan_iterations: number;
max_step_num: number;
max_search_results?: number;
interrupt_feedback?: string;
} = {
thread_id: "__mock__",
auto_accepted_plan: false,
max_plan_iterations: 3,
max_step_num: 1,
max_search_results: 3,
interrupt_feedback: undefined,
},
options: { abortSignal?: AbortSignal } = {},
@@ -157,6 +160,7 @@ export async function fetchReplayTitle() {
auto_accepted_plan: false,
max_plan_iterations: 3,
max_step_num: 1,
max_search_results: 3,
},
{},
);
+2
View File
@@ -13,6 +13,7 @@ const DEFAULT_SETTINGS: SettingsState = {
enableBackgroundInvestigation: false,
maxPlanIterations: 1,
maxStepNum: 3,
maxSearchResults: 3,
},
mcp: {
servers: [],
@@ -25,6 +26,7 @@ export type SettingsState = {
enableBackgroundInvestigation: boolean;
maxPlanIterations: number;
maxStepNum: number;
maxSearchResults: number;
};
mcp: {
servers: MCPServerMetadata[];
+1
View File
@@ -104,6 +104,7 @@ export async function sendMessage(
settings.enableBackgroundInvestigation ?? true,
max_plan_iterations: settings.maxPlanIterations,
max_step_num: settings.maxStepNum,
max_search_results: settings.maxSearchResults,
mcp_settings: settings.mcpSettings,
},
options,