Files
deer-flow/src/tools/search.py
T
Willem Jiang 4c2fe2e7f5 test: add more unit tests of tools (#315)
* test: add more test on test_tts.py

* test: add unit test of search and retriever in tools

* test: remove the main code of search.py

* test: add the tavily_search unit test

* reformat the code

* test: add unit tests of tools

* Added the pytest-asyncio dependency

* added the license header of test_tavily_search_api_wrapper.py
2025-06-12 20:43:32 +08:00

62 lines
2.2 KiB
Python

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import json
import logging
import os
from langchain_community.tools import BraveSearch, DuckDuckGoSearchResults
from langchain_community.tools.arxiv import ArxivQueryRun
from langchain_community.utilities import ArxivAPIWrapper, BraveSearchWrapper
from src.config import SearchEngine, SELECTED_SEARCH_ENGINE
from src.tools.tavily_search.tavily_search_results_with_images import (
TavilySearchResultsWithImages,
)
from src.tools.decorators import create_logged_tool
logger = logging.getLogger(__name__)

# Wrap every supported search backend class with ``create_logged_tool`` to get
# its logged counterpart. These wrapped classes are what
# ``get_web_search_tool`` instantiates; the order below is the order in which
# the wrappers are created.
(
    LoggedTavilySearch,
    LoggedDuckDuckGoSearch,
    LoggedBraveSearch,
    LoggedArxivSearch,
) = (
    create_logged_tool(tool_cls)
    for tool_cls in (
        TavilySearchResultsWithImages,
        DuckDuckGoSearchResults,
        BraveSearch,
        ArxivQueryRun,
    )
)
# Get the selected search tool
def get_web_search_tool(max_search_results: int):
    """Build the logged web search tool for the configured search engine.

    The backend is chosen by comparing ``SELECTED_SEARCH_ENGINE`` against the
    ``.value`` of each ``SearchEngine`` member. Every branch registers its tool
    under the name ``"web_search"``.

    Args:
        max_search_results: Number of results to request from the backend.
            It is forwarded as ``max_results`` (Tavily), ``num_results``
            (DuckDuckGo), ``search_kwargs["count"]`` (Brave), or as both
            ``top_k_results`` and ``load_max_docs`` (arXiv).

    Returns:
        An instance of the logged wrapper class for the selected backend.

    Raises:
        ValueError: If ``SELECTED_SEARCH_ENGINE`` matches none of the
            supported engines.
    """
    if SELECTED_SEARCH_ENGINE == SearchEngine.TAVILY.value:
        return LoggedTavilySearch(
            name="web_search",
            max_results=max_search_results,
            include_raw_content=True,
            include_images=True,
            include_image_descriptions=True,
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.DUCKDUCKGO.value:
        return LoggedDuckDuckGoSearch(
            name="web_search",
            num_results=max_search_results,
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.BRAVE_SEARCH.value:
        brave_api_key = os.getenv("BRAVE_SEARCH_API_KEY", "")
        if not brave_api_key:
            # The tool is still built with an empty key (unchanged best-effort
            # behaviour). Warn now so the misconfiguration is visible at setup
            # time rather than only as a failed request later.
            logger.warning(
                "BRAVE_SEARCH_API_KEY is not set; Brave search requests are "
                "likely to be rejected."
            )
        return LoggedBraveSearch(
            name="web_search",
            search_wrapper=BraveSearchWrapper(
                api_key=brave_api_key,
                search_kwargs={"count": max_search_results},
            ),
        )
    elif SELECTED_SEARCH_ENGINE == SearchEngine.ARXIV.value:
        return LoggedArxivSearch(
            name="web_search",
            api_wrapper=ArxivAPIWrapper(
                top_k_results=max_search_results,
                load_max_docs=max_search_results,
                load_all_available_meta=True,
            ),
        )
    else:
        raise ValueError(f"Unsupported search engine: {SELECTED_SEARCH_ENGINE}")