mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 21:55:59 +00:00
feat(community): add Serper Google Images provider for image_search (#3575)
* feat(community): add Serper Google Images provider for image_search Add a Serper-backed `image_search` tool alongside the existing Serper `web_search` provider, so users with a SERPER_API_KEY can pull Google Images results as reference images for downstream image generation. - Share request/response handling between web_search and image_search via `_serper_post` / `_response_items`, with bounded `max_results` (capped at 10) and query normalization. - Add a best-effort SSRF guard (`_safe_public_url`) that rejects non-http(s), localhost and private/non-global IP image URLs; filtered entries are dropped and never consume the result limit. - doctor: flag literal `api_key` values in config as a warning and steer users toward `.env` + `$SERPER_API_KEY`. - Docs/config: document the Serper image_search provider and SERPER_API_KEY, and discourage committing literal keys to config.yaml. - Tests: cover the provider end-to-end (100% line coverage on tools.py) and the doctor literal-key warning path. * fix(community): block obfuscated IPv4 literals in Serper image SSRF guard The image_search SSRF guard only rejected dotted-decimal IP literals; encoded forms such as decimal (http://2130706433/), hex (0x7f000001) and octal (0177.0.0.1) raised ValueError in ip_address() and were allowed through, even though many HTTP clients resolve them to private addresses like 127.0.0.1. Add _decode_ipv4() to permissively decode these inet_aton-style encodings and apply the same is_global check; hostnames that do not decode to an IP (e.g. cafe.com) are still treated as hosts and left to fetch-time re-validation. Addresses PR review feedback. Tests cover decimal/hex/octal loopback and private encodings plus non-IP edge cases; tools.py stays at 100% line coverage. * test(community): cover IPv4-mapped IPv6 URL filtering * fix(community): address Serper image search review feedback - Block trailing-dot hostname SSRF bypass (localhost./127.0.0.1.) 
in _safe_public_url by stripping the FQDN root label before checks. - Keep a filtered image/thumbnail URL empty instead of collapsing onto its counterpart, preserving the high-res/preview contract. - Evaluate the SSRF guard once per field rather than twice. - Treat a null-typed organic/images field as "no results" rather than a malformed payload. - doctor.py: when a config $VAR is unset, fall through to the default env var before reporting it as not set.
This commit is contained in:
@@ -1,3 +1,3 @@
|
||||
from .tools import web_search_tool
|
||||
from .tools import image_search_tool, web_search_tool
|
||||
|
||||
__all__ = ["web_search_tool"]
|
||||
__all__ = ["image_search_tool", "web_search_tool"]
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
"""
|
||||
Web Search Tool - Search the web using Serper (Google Search API).
|
||||
Web and image search tools powered by Serper (Google Search API).
|
||||
|
||||
Serper provides real-time Google Search results via a JSON API.
|
||||
An API key is required. Sign up at https://serper.dev to get one.
|
||||
Serper provides real-time Google Search and Google Images results via a JSON
|
||||
API. An API key is required. Sign up at https://serper.dev to get one.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from ipaddress import IPv4Address, ip_address
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import httpx
|
||||
from langchain.tools import tool
|
||||
@@ -16,43 +18,168 @@ from deerflow.config import get_app_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SERPER_ENDPOINT = "https://google.serper.dev/search"
|
||||
_api_key_warned = False
|
||||
_SERPER_SEARCH_ENDPOINT = "https://google.serper.dev/search"
|
||||
_SERPER_IMAGES_ENDPOINT = "https://google.serper.dev/images"
|
||||
_SERPER_MAX_RESULTS = 10
|
||||
_api_key_warned: set[str] = set()
|
||||
|
||||
|
||||
def _get_api_key() -> str | None:
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
def _get_api_key(tool_name: str) -> str | None:
|
||||
config = get_app_config().get_tool_config(tool_name)
|
||||
if config is not None:
|
||||
api_key = config.model_extra.get("api_key")
|
||||
if isinstance(api_key, str) and api_key.strip():
|
||||
return api_key
|
||||
return os.getenv("SERPER_API_KEY")
|
||||
return api_key.strip()
|
||||
env_key = os.getenv("SERPER_API_KEY")
|
||||
if isinstance(env_key, str) and env_key.strip():
|
||||
return env_key.strip()
|
||||
return None
|
||||
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
"""Search the web for information using Google Search via Serper.
|
||||
def _coerce_max_results(value: object, default: int = 5, max_allowed: int = _SERPER_MAX_RESULTS) -> int:
    """Coerce config/parameter input into a bounded positive result count.

    Args:
        value: Raw value from the tool call or from config; may be any type.
        default: Count returned when ``value`` cannot be interpreted as a
            positive integer.
        max_allowed: Upper bound applied to an otherwise valid count.

    Returns:
        ``default`` when ``value`` is not convertible to ``int`` or converts
        to a value ``<= 0``; otherwise the converted count capped at
        ``max_allowed``.
    """
    try:
        count = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError: int() of a float infinity (e.g. YAML ``.inf``) does
        # not raise ValueError, so it must be listed explicitly to fall back.
        return default
    if count <= 0:
        return default
    return min(count, max_allowed)
|
||||
|
||||
Args:
|
||||
query: Search keywords describing what you want to find. Be specific for better results.
|
||||
max_results: Maximum number of search results to return. Default is 5.
|
||||
|
||||
def _missing_key_error(query: str, tool_name: str) -> str:
    """Build the JSON error returned when no Serper API key is available.

    A warning is logged the first time this is called for a given
    ``tool_name``; the name is then recorded in ``_api_key_warned`` so later
    calls for the same tool stay quiet.
    """
    first_time = tool_name not in _api_key_warned
    if first_time:
        _api_key_warned.add(tool_name)
        logger.warning("Serper API key is not set for '%s'. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev", tool_name)
    payload = {"error": "SERPER_API_KEY is not configured", "query": query}
    return json.dumps(payload, ensure_ascii=False)
|
||||
|
||||
|
||||
def _unexpected_format_error(query: str) -> str:
|
||||
return json.dumps(
|
||||
{"error": "Serper returned an unexpected response format", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
|
||||
def _response_items(data: dict, field: str, query: str) -> tuple[list[dict] | None, str | None]:
|
||||
items = data.get(field)
|
||||
# Treat a missing or null field as "no results" (some APIs return
|
||||
# ``{"organic": null}`` to signal that) rather than a malformed payload.
|
||||
if items is None:
|
||||
return [], None
|
||||
if not isinstance(items, list):
|
||||
logger.error("Serper returned unexpected '%s' payload type: %s", field, type(items).__name__)
|
||||
return None, _unexpected_format_error(query)
|
||||
return [item for item in items if isinstance(item, dict)], None
|
||||
|
||||
|
||||
def _clean_query(query: str) -> str:
|
||||
"""Normalize a raw query into the value actually sent to Serper."""
|
||||
query = query.strip()
|
||||
if len(query) > 500:
|
||||
query = query[:500]
|
||||
return query
|
||||
|
||||
|
||||
def _decode_ipv4(host: str) -> IPv4Address | None:
|
||||
"""Decode obfuscated IPv4 literals that ``ip_address`` rejects.
|
||||
|
||||
Mirrors the permissive ``inet_aton`` parsing many HTTP clients use, so that
|
||||
integer (``2130706433``), hex (``0x7f000001``) and octal (``0177.0.0.1``)
|
||||
encodings of an address are recognized. Returns an ``IPv4Address`` when the
|
||||
host decodes to one, otherwise ``None`` (e.g. real domains like
|
||||
``cafe.com`` fail to decode and are left for the caller to treat as a host).
|
||||
"""
|
||||
global _api_key_warned
|
||||
parts = host.split(".")
|
||||
if not 1 <= len(parts) <= 4:
|
||||
return None
|
||||
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None and "max_results" in config.model_extra:
|
||||
max_results = config.model_extra.get("max_results", max_results)
|
||||
values: list[int] = []
|
||||
for part in parts:
|
||||
if not part:
|
||||
return None
|
||||
try:
|
||||
if part.startswith(("0x", "0X")):
|
||||
values.append(int(part, 16))
|
||||
elif part.startswith("0") and len(part) > 1:
|
||||
values.append(int(part, 8))
|
||||
else:
|
||||
values.append(int(part, 10))
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
api_key = _get_api_key()
|
||||
if not api_key:
|
||||
if not _api_key_warned:
|
||||
_api_key_warned = True
|
||||
logger.warning("Serper API key is not set. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev")
|
||||
return json.dumps(
|
||||
{"error": "SERPER_API_KEY is not configured", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
*leading, last = values
|
||||
for value in leading:
|
||||
if not 0 <= value <= 0xFF:
|
||||
return None
|
||||
max_last = (1 << (8 * (4 - len(leading)))) - 1
|
||||
if not 0 <= last <= max_last:
|
||||
return None
|
||||
|
||||
result = 0
|
||||
for value in leading:
|
||||
result = (result << 8) | value
|
||||
result = (result << (8 * (4 - len(leading)))) | last
|
||||
return ip_address(result)
|
||||
|
||||
|
||||
def _is_url_present(value: object) -> bool:
|
||||
"""Return ``True`` when *value* is a non-empty URL string.
|
||||
|
||||
Used to distinguish a field that was *absent* (eligible for cross-field
|
||||
fallback) from one that was *present but filtered* by the SSRF guard (which
|
||||
must stay empty rather than collapse onto its counterpart).
|
||||
"""
|
||||
return isinstance(value, str) and bool(value.strip())
|
||||
|
||||
|
||||
def _safe_public_url(value: object) -> str:
|
||||
"""Return ``value`` only if it is a safe, public http(s) URL, else "".
|
||||
|
||||
This is a best-effort SSRF guard that rejects non-http(s) schemes,
|
||||
``localhost``, and private/non-global IP literals (including obfuscated
|
||||
decimal/hex/octal encodings). It only inspects the URL string and cannot
|
||||
catch public hostnames that resolve to internal IPs (e.g. DNS rebinding);
|
||||
any consumer that actually downloads these URLs must re-validate the
|
||||
resolved IP at fetch time.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return ""
|
||||
url = value.strip()
|
||||
parsed = urlparse(url)
|
||||
if parsed.scheme not in {"http", "https"} or not parsed.netloc or not parsed.hostname:
|
||||
return ""
|
||||
|
||||
# Strip a single trailing dot (FQDN root label). ``localhost.`` and
|
||||
# ``127.0.0.1.`` resolve to loopback on common resolvers but would
|
||||
# otherwise slip past the localhost/IP checks below.
|
||||
host = parsed.hostname.lower().rstrip(".")
|
||||
if not host:
|
||||
return ""
|
||||
if host == "localhost" or host.endswith(".localhost"):
|
||||
return ""
|
||||
|
||||
try:
|
||||
ip = ip_address(host)
|
||||
except ValueError:
|
||||
ip = _decode_ipv4(host)
|
||||
if ip is None:
|
||||
return url
|
||||
return url if ip.is_global else ""
|
||||
|
||||
|
||||
def _serper_post(endpoint: str, api_key: str, query: str, max_results: int) -> tuple[dict | None, str | None]:
|
||||
"""Send a POST request to a Serper endpoint.
|
||||
|
||||
``query`` is expected to already be normalized via :func:`_clean_query`.
|
||||
|
||||
Returns a ``(data, error_json)`` tuple: on success ``data`` is the parsed
|
||||
JSON response and ``error_json`` is ``None``; on failure ``data`` is ``None``
|
||||
and ``error_json`` is a serialized structured error ready to return.
|
||||
"""
|
||||
headers = {
|
||||
"X-API-KEY": api_key,
|
||||
"Content-Type": "application/json",
|
||||
@@ -61,23 +188,56 @@ def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
|
||||
try:
|
||||
with httpx.Client(timeout=30) as client:
|
||||
response = client.post(_SERPER_ENDPOINT, headers=headers, json=payload)
|
||||
response = client.post(endpoint, headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
if not isinstance(data, dict):
|
||||
logger.error("Serper returned an unexpected payload type: %s", type(data).__name__)
|
||||
return None, _unexpected_format_error(query)
|
||||
return data, None
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Serper API returned HTTP {e.response.status_code}: {e.response.text}")
|
||||
return json.dumps(
|
||||
resp_text = (e.response.text or "")[:500]
|
||||
logger.error("Serper API returned HTTP %s: %s", e.response.status_code, resp_text)
|
||||
return None, json.dumps(
|
||||
{"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Serper search failed: {type(e).__name__}: {e}")
|
||||
return json.dumps({"error": str(e), "query": query}, ensure_ascii=False)
|
||||
logger.error("Serper request failed: %s: %s", type(e).__name__, str(e)[:500])
|
||||
return None, json.dumps({"error": str(e)[:500], "query": query}, ensure_ascii=False)
|
||||
|
||||
organic = data.get("organic", [])
|
||||
|
||||
@tool("web_search", parse_docstring=True)
|
||||
def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
"""Search the web for information using Google Search via Serper.
|
||||
|
||||
Args:
|
||||
query: Search keywords describing what you want to find. Be specific for better results.
|
||||
max_results: Maximum number of search results to return. Default is 5, capped at 10.
|
||||
"""
|
||||
config = get_app_config().get_tool_config("web_search")
|
||||
if config is not None and "max_results" in config.model_extra:
|
||||
max_results = config.model_extra.get("max_results", max_results)
|
||||
max_results = _coerce_max_results(max_results)
|
||||
query = _clean_query(query)
|
||||
|
||||
api_key = _get_api_key("web_search")
|
||||
if not api_key:
|
||||
return _missing_key_error(query, "web_search")
|
||||
|
||||
data, error_json = _serper_post(_SERPER_SEARCH_ENDPOINT, api_key, query, max_results)
|
||||
if error_json is not None:
|
||||
return error_json
|
||||
|
||||
organic, error_json = _response_items(data, "organic", query)
|
||||
if error_json is not None:
|
||||
return error_json
|
||||
if not organic:
|
||||
return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)
|
||||
|
||||
# Search result links are returned verbatim (not passed through
|
||||
# _safe_public_url): they are surfaced as citations for the model to read,
|
||||
# not fetched/downloaded by this tool, unlike image_search image URLs.
|
||||
normalized_results = [
|
||||
{
|
||||
"title": r.get("title", ""),
|
||||
@@ -93,3 +253,71 @@ def web_search_tool(query: str, max_results: int = 5) -> str:
|
||||
"results": normalized_results,
|
||||
}
|
||||
return json.dumps(output, indent=2, ensure_ascii=False)
|
||||
|
||||
|
||||
@tool("image_search", parse_docstring=True)
def image_search_tool(query: str, max_results: int = 5) -> str:
    """Search for images online using Google Images via Serper. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy.

    The returned image URLs can be used as reference images in image generation to significantly improve quality.

    Args:
        query: Search keywords describing the images you want to find. Be specific for better results (e.g., "Japanese woman street photography 1990s" instead of just "woman").
        max_results: Maximum number of images to return. Default is 5, capped at 10.
    """
    # A ``max_results`` entry in the tool's config replaces the caller-supplied
    # value; either way the count is then coerced and bounded.
    tool_config = get_app_config().get_tool_config("image_search")
    if tool_config is not None and "max_results" in tool_config.model_extra:
        max_results = tool_config.model_extra.get("max_results", max_results)
    max_results = _coerce_max_results(max_results)
    query = _clean_query(query)

    api_key = _get_api_key("image_search")
    if not api_key:
        return _missing_key_error(query, "image_search")

    data, error_json = _serper_post(_SERPER_IMAGES_ENDPOINT, api_key, query, max_results)
    if error_json is not None:
        return error_json

    images, error_json = _response_items(data, "images", query)
    if error_json is not None:
        return error_json
    if not images:
        return json.dumps({"error": "No images found", "query": query}, ensure_ascii=False)

    collected = []
    for entry in images:
        raw_image = entry.get("imageUrl")
        raw_thumb = entry.get("thumbnailUrl")
        # Run the SSRF guard exactly once per field.
        safe_image = _safe_public_url(raw_image)
        safe_thumb = _safe_public_url(raw_thumb)

        # A field borrows its counterpart only when it was *absent*. One that
        # was present but rejected by the guard stays empty, so a dropped
        # high-res URL is never passed off as the preview (or the reverse).
        image_url = safe_image
        if not image_url and not _is_url_present(raw_image):
            image_url = safe_thumb
        thumbnail_url = safe_thumb
        if not thumbnail_url and not _is_url_present(raw_thumb):
            thumbnail_url = safe_image

        if not (image_url or thumbnail_url):
            # Entries with no usable URL are skipped and do not count
            # towards the result limit.
            continue

        collected.append(
            {
                "title": entry.get("title", ""),
                "image_url": image_url,
                "thumbnail_url": thumbnail_url,
            }
        )
        if len(collected) >= max_results:
            break

    if not collected:
        return json.dumps({"error": "No safe image URLs found", "query": query}, ensure_ascii=False)

    output = {
        "query": query,
        "total_results": len(collected),
        "results": collected,
        "usage_hint": "Use the 'image_url' values as reference images in image generation. Download them first if needed.",
    }
    return json.dumps(output, indent=2, ensure_ascii=False)
|
||||
|
||||
Reference in New Issue
Block a user