Files
deer-flow/backend/packages/harness/deerflow/community/serper/tools.py
T
Ryker_Feng 0bbbbc06f4 feat(community): add Serper Google Images provider for image_search (#3575)
* feat(community): add Serper Google Images provider for image_search

Add a Serper-backed `image_search` tool alongside the existing Serper
`web_search` provider, so users with a SERPER_API_KEY can pull Google
Images results as reference images for downstream image generation.

- Share request/response handling between web_search and image_search
  via `_serper_post` / `_response_items`, with bounded `max_results`
  (capped at 10) and query normalization.
- Add a best-effort SSRF guard (`_safe_public_url`) that rejects
  non-http(s), localhost and private/non-global IP image URLs; filtered
  entries are dropped and never consume the result limit.
- doctor: flag literal `api_key` values in config as a warning and steer
  users toward `.env` + `$SERPER_API_KEY`.
- Docs/config: document the Serper image_search provider and SERPER_API_KEY,
  and discourage committing literal keys to config.yaml.
- Tests: cover the provider end-to-end (100% line coverage on tools.py)
  and the doctor literal-key warning path.

* fix(community): block obfuscated IPv4 literals in Serper image SSRF guard

The image_search SSRF guard only rejected dotted-decimal IP literals; encoded
forms such as decimal (http://2130706433/), hex (0x7f000001) and octal
(0177.0.0.1) raised ValueError in ip_address() and were allowed through, even
though many HTTP clients resolve them to non-public addresses such as the
loopback address 127.0.0.1.

Add _decode_ipv4() to permissively decode these inet_aton-style encodings and
apply the same is_global check; hostnames that do not decode to an IP (e.g.
cafe.com) are still treated as hosts and left to fetch-time re-validation.

Addresses PR review feedback. Tests cover decimal/hex/octal loopback and
private encodings plus non-IP edge cases; tools.py stays at 100% line coverage.

* test(community): cover IPv4-mapped IPv6 URL filtering

* fix(community): address Serper image search review feedback

- Block trailing-dot hostname SSRF bypass (localhost./127.0.0.1.) in
  _safe_public_url by stripping the FQDN root label before checks.
- Keep a filtered image/thumbnail URL empty instead of collapsing onto
  its counterpart, preserving the high-res/preview contract.
- Evaluate the SSRF guard once per field rather than twice.
- Treat a null-typed organic/images field as "no results" rather than a
  malformed payload.
- doctor.py: when a config $VAR is unset, fall through to the default env
  var before reporting it as not set.
2026-06-18 07:36:35 +08:00

324 lines
12 KiB
Python

"""
Web and image search tools powered by Serper (Google Search API).
Serper provides real-time Google Search and Google Images results via a JSON
API. An API key is required. Sign up at https://serper.dev to get one.
"""
import json
import logging
import os
from ipaddress import IPv4Address, ip_address
from urllib.parse import urlparse
import httpx
from langchain.tools import tool
from deerflow.config import get_app_config
logger = logging.getLogger(__name__)
_SERPER_SEARCH_ENDPOINT = "https://google.serper.dev/search"
_SERPER_IMAGES_ENDPOINT = "https://google.serper.dev/images"
_SERPER_MAX_RESULTS = 10
_api_key_warned: set[str] = set()
def _get_api_key(tool_name: str) -> str | None:
    """Look up the Serper API key for *tool_name*.

    Precedence: a non-blank ``api_key`` in the tool's config entry, then a
    non-blank ``SERPER_API_KEY`` environment variable. Surrounding whitespace
    is removed before the key is returned.

    Returns:
        The stripped key, or ``None`` if neither source supplies one.
    """

    def _usable(candidate: object) -> str | None:
        # Only strings that are non-empty after stripping count as a key.
        if isinstance(candidate, str) and candidate.strip():
            return candidate.strip()
        return None

    tool_config = get_app_config().get_tool_config(tool_name)
    from_config = None if tool_config is None else _usable(tool_config.model_extra.get("api_key"))
    if from_config is not None:
        return from_config
    # The environment is consulted only when the config has no usable key.
    return _usable(os.getenv("SERPER_API_KEY"))
def _coerce_max_results(value: object, default: int = 5, max_allowed: int = _SERPER_MAX_RESULTS) -> int:
    """Coerce config/parameter input into a bounded positive result count.

    Args:
        value: Raw count from the tool call or the tool's config entry; may be
            of any type. Floats are truncated toward zero by ``int()``.
        default: Returned when *value* cannot be converted to an int or is not
            positive.
        max_allowed: Upper bound applied to a valid positive count.

    Returns:
        ``min(int(value), max_allowed)`` for a positive count, otherwise
        *default*.
    """
    try:
        count = int(value)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")), e.g. a YAML ``.inf`` in the
        # config. Like other unusable input, it falls back to the default
        # instead of crashing the tool.
        return default
    if count <= 0:
        return default
    return min(count, max_allowed)
def _missing_key_error(query: str, tool_name: str) -> str:
    """Build the structured error returned when no Serper API key is available.

    The accompanying log warning is emitted only the first time a given
    *tool_name* reaches this path; later calls return the error without logging.
    """
    first_time = tool_name not in _api_key_warned
    if first_time:
        _api_key_warned.add(tool_name)
        logger.warning(
            "Serper API key is not set for '%s'. Set SERPER_API_KEY in your environment or provide api_key in config.yaml. Sign up at https://serper.dev",
            tool_name,
        )
    error_payload = {"error": "SERPER_API_KEY is not configured", "query": query}
    return json.dumps(error_payload, ensure_ascii=False)
def _unexpected_format_error(query: str) -> str:
return json.dumps(
{"error": "Serper returned an unexpected response format", "query": query},
ensure_ascii=False,
)
def _response_items(data: dict, field: str, query: str) -> tuple[list[dict] | None, str | None]:
items = data.get(field)
# Treat a missing or null field as "no results" (some APIs return
# ``{"organic": null}`` to signal that) rather than a malformed payload.
if items is None:
return [], None
if not isinstance(items, list):
logger.error("Serper returned unexpected '%s' payload type: %s", field, type(items).__name__)
return None, _unexpected_format_error(query)
return [item for item in items if isinstance(item, dict)], None
def _clean_query(query: str) -> str:
"""Normalize a raw query into the value actually sent to Serper."""
query = query.strip()
if len(query) > 500:
query = query[:500]
return query
def _decode_ipv4(host: str) -> IPv4Address | None:
"""Decode obfuscated IPv4 literals that ``ip_address`` rejects.
Mirrors the permissive ``inet_aton`` parsing many HTTP clients use, so that
integer (``2130706433``), hex (``0x7f000001``) and octal (``0177.0.0.1``)
encodings of an address are recognized. Returns an ``IPv4Address`` when the
host decodes to one, otherwise ``None`` (e.g. real domains like
``cafe.com`` fail to decode and are left for the caller to treat as a host).
"""
parts = host.split(".")
if not 1 <= len(parts) <= 4:
return None
values: list[int] = []
for part in parts:
if not part:
return None
try:
if part.startswith(("0x", "0X")):
values.append(int(part, 16))
elif part.startswith("0") and len(part) > 1:
values.append(int(part, 8))
else:
values.append(int(part, 10))
except ValueError:
return None
*leading, last = values
for value in leading:
if not 0 <= value <= 0xFF:
return None
max_last = (1 << (8 * (4 - len(leading)))) - 1
if not 0 <= last <= max_last:
return None
result = 0
for value in leading:
result = (result << 8) | value
result = (result << (8 * (4 - len(leading)))) | last
return ip_address(result)
def _is_url_present(value: object) -> bool:
"""Return ``True`` when *value* is a non-empty URL string.
Used to distinguish a field that was *absent* (eligible for cross-field
fallback) from one that was *present but filtered* by the SSRF guard (which
must stay empty rather than collapse onto its counterpart).
"""
return isinstance(value, str) and bool(value.strip())
def _safe_public_url(value: object) -> str:
"""Return ``value`` only if it is a safe, public http(s) URL, else "".
This is a best-effort SSRF guard that rejects non-http(s) schemes,
``localhost``, and private/non-global IP literals (including obfuscated
decimal/hex/octal encodings). It only inspects the URL string and cannot
catch public hostnames that resolve to internal IPs (e.g. DNS rebinding);
any consumer that actually downloads these URLs must re-validate the
resolved IP at fetch time.
"""
if not isinstance(value, str):
return ""
url = value.strip()
parsed = urlparse(url)
if parsed.scheme not in {"http", "https"} or not parsed.netloc or not parsed.hostname:
return ""
# Strip a single trailing dot (FQDN root label). ``localhost.`` and
# ``127.0.0.1.`` resolve to loopback on common resolvers but would
# otherwise slip past the localhost/IP checks below.
host = parsed.hostname.lower().rstrip(".")
if not host:
return ""
if host == "localhost" or host.endswith(".localhost"):
return ""
try:
ip = ip_address(host)
except ValueError:
ip = _decode_ipv4(host)
if ip is None:
return url
return url if ip.is_global else ""
def _serper_post(endpoint: str, api_key: str, query: str, max_results: int) -> tuple[dict | None, str | None]:
    """Send a POST request to a Serper endpoint.

    ``query`` is expected to already be normalized via :func:`_clean_query`.

    Args:
        endpoint: Full Serper endpoint URL to POST to.
        api_key: Serper API key, sent in the ``X-API-KEY`` header.
        query: Search query, sent as ``q``.
        max_results: Requested result count, sent as ``num``.

    Returns:
        A ``(data, error_json)`` tuple. On success ``data`` is the parsed JSON
        object and ``error_json`` is ``None``. On failure ``data`` is ``None``
        and ``error_json`` is a serialized structured error ready to return.
        Failures are reported this way rather than raised.
    """
    headers = {
        "X-API-KEY": api_key,
        "Content-Type": "application/json",
    }
    payload = {"q": query, "num": max_results}
    try:
        with httpx.Client(timeout=30) as client:
            response = client.post(endpoint, headers=headers, json=payload)
            response.raise_for_status()
            data = response.json()
    except httpx.HTTPStatusError as e:
        resp_text = (e.response.text or "")[:500]
        logger.error("Serper API returned HTTP %s: %s", e.response.status_code, resp_text)
        return None, json.dumps(
            {"error": f"Serper API error: HTTP {e.response.status_code}", "query": query},
            ensure_ascii=False,
        )
    except json.JSONDecodeError as e:
        # A 2xx response whose body is not JSON (e.g. an HTML page from a
        # proxy) is a malformed payload; report it like a non-dict payload.
        logger.error("Serper returned a non-JSON response body: %s", str(e)[:500])
        return None, _unexpected_format_error(query)
    except Exception as e:
        # Tool boundary: network failures, timeouts, etc. become a structured
        # error. Some exceptions stringify to "", so fall back to the exception
        # type name to keep the error informative.
        detail = str(e)[:500] or type(e).__name__
        logger.error("Serper request failed: %s: %s", type(e).__name__, str(e)[:500])
        return None, json.dumps({"error": detail, "query": query}, ensure_ascii=False)

    if not isinstance(data, dict):
        logger.error("Serper returned an unexpected payload type: %s", type(data).__name__)
        return None, _unexpected_format_error(query)
    return data, None
@tool("web_search", parse_docstring=True)
def web_search_tool(query: str, max_results: int = 5) -> str:
    """Search the web for information using Google Search via Serper.

    Args:
        query: Search keywords describing what you want to find. Be specific for better results.
        max_results: Maximum number of search results to return. Default is 5, capped at 10.
    """
    # NOTE: ``@tool(parse_docstring=True)`` turns the docstring above into the
    # tool description and argument schema shown to the model. Keep its wording
    # and Google-style layout intact.
    #
    # A ``max_results`` in the web_search config entry takes precedence over the
    # caller's value. Either way it is bounded via _coerce_max_results.
    config = get_app_config().get_tool_config("web_search")
    if config is not None and "max_results" in config.model_extra:
        max_results = config.model_extra.get("max_results", max_results)
    max_results = _coerce_max_results(max_results)

    query = _clean_query(query)
    if not query:
        # A blank query cannot produce useful results. Fail fast instead of
        # spending a Serper request on it.
        return json.dumps({"error": "Query must not be empty", "query": query}, ensure_ascii=False)

    api_key = _get_api_key("web_search")
    if not api_key:
        return _missing_key_error(query, "web_search")

    data, error_json = _serper_post(_SERPER_SEARCH_ENDPOINT, api_key, query, max_results)
    if error_json is not None:
        return error_json

    organic, error_json = _response_items(data, "organic", query)
    if error_json is not None:
        return error_json
    if not organic:
        return json.dumps({"error": "No results found", "query": query}, ensure_ascii=False)

    # Search result links are returned verbatim (not passed through
    # _safe_public_url). They are surfaced as citations for the model to read,
    # not fetched/downloaded by this tool, unlike image_search image URLs.
    normalized_results = [
        {
            "title": r.get("title", ""),
            "url": r.get("link", ""),
            "content": r.get("snippet", ""),
        }
        for r in organic[:max_results]
    ]
    output = {
        "query": query,
        "total_results": len(normalized_results),
        "results": normalized_results,
    }
    return json.dumps(output, indent=2, ensure_ascii=False)
@tool("image_search", parse_docstring=True)
def image_search_tool(query: str, max_results: int = 5) -> str:
    """Search for images online using Google Images via Serper. Use this tool BEFORE image generation to find reference images for characters, portraits, objects, scenes, or any content requiring visual accuracy.

    The returned image URLs can be used as reference images in image generation to significantly improve quality.

    Args:
        query: Search keywords describing the images you want to find. Be specific for better results (e.g., "Japanese woman street photography 1990s" instead of just "woman").
        max_results: Maximum number of images to return. Default is 5, capped at 10.
    """
    # NOTE: ``@tool(parse_docstring=True)`` turns the docstring above into the
    # tool description and argument schema shown to the model. Keep its wording
    # and Google-style layout intact.
    #
    # A ``max_results`` in the image_search config entry takes precedence over
    # the caller's value. Either way it is bounded via _coerce_max_results.
    config = get_app_config().get_tool_config("image_search")
    if config is not None and "max_results" in config.model_extra:
        max_results = config.model_extra.get("max_results", max_results)
    max_results = _coerce_max_results(max_results)

    query = _clean_query(query)
    if not query:
        # A blank query cannot produce useful results. Fail fast instead of
        # spending a Serper request on it.
        return json.dumps({"error": "Query must not be empty", "query": query}, ensure_ascii=False)

    api_key = _get_api_key("image_search")
    if not api_key:
        return _missing_key_error(query, "image_search")

    data, error_json = _serper_post(_SERPER_IMAGES_ENDPOINT, api_key, query, max_results)
    if error_json is not None:
        return error_json

    images, error_json = _response_items(data, "images", query)
    if error_json is not None:
        return error_json
    if not images:
        return json.dumps({"error": "No images found", "query": query}, ensure_ascii=False)

    normalized_results = []
    # Walk the full response rather than a max_results slice, so entries
    # dropped by the SSRF filter do not consume the result limit.
    for r in images:
        raw_image = r.get("imageUrl")
        raw_thumb = r.get("thumbnailUrl")
        # Evaluate the (non-trivial) SSRF guard once per field instead of twice.
        safe_image = _safe_public_url(raw_image)
        safe_thumb = _safe_public_url(raw_thumb)
        # Cross-fall back only when the other field was *absent*. A field that
        # was present but failed the SSRF filter is left empty rather than
        # collapsed onto its counterpart. A dropped high-res URL therefore never
        # silently masquerades as the preview (and vice versa), preserving the
        # high-res/preview contract callers rely on.
        image_url = safe_image or (safe_thumb if not _is_url_present(raw_image) else "")
        thumbnail_url = safe_thumb or (safe_image if not _is_url_present(raw_thumb) else "")
        if not image_url and not thumbnail_url:
            continue
        normalized_results.append(
            {
                "title": r.get("title", ""),
                "image_url": image_url,
                "thumbnail_url": thumbnail_url,
            }
        )
        if len(normalized_results) >= max_results:
            break

    if not normalized_results:
        return json.dumps({"error": "No safe image URLs found", "query": query}, ensure_ascii=False)
    output = {
        "query": query,
        "total_results": len(normalized_results),
        "results": normalized_results,
        "usage_hint": "Use the 'image_url' values as reference images in image generation. Download them first if needed.",
    }
    return json.dumps(output, indent=2, ensure_ascii=False)