refactor(tool-search): consolidate MCP metadata tag and harden deferred-tool setup (#3370)

Follow-up to #3342 (deferred MCP tool loading). Maintainability cleanup plus
hardening of malformed/empty tool_search queries; no change to the deferral
mechanism or search ranking.

- Add deerflow/tools/mcp_metadata.py as the single source of truth for the
  "deerflow_mcp" tag (MCP_TOOL_METADATA_KEY + tag_mcp_tool + public
  is_mcp_tool). Removes the duplicated magic string and the private,
  cross-module _is_mcp_tool import.
- tool_search.search: never raise on model-generated input. Extract
  _compile_catalog_regex (shared compile-with-literal-fallback); return empty
  for empty/whitespace queries and a bare "+" instead of matching everything
  or raising IndexError.
- DeferredToolSetup: document the empty-vs-populated invariant.
- build_deferred_tool_setup: comment the two distinct empty-return branches.
- _assemble_deferred: add return type, rename local to deferred_setup, build
  the final list with an explicit append.
- Tests: use tag_mcp_tool instead of per-file tag helpers; cover empty and
  bare-"+" queries.
This commit is contained in:
AochenShen99
2026-06-05 15:21:41 +08:00
committed by GitHub
parent 28b1da2172
commit 2bbc7879fa
8 changed files with 123 additions and 47 deletions
@@ -18,7 +18,10 @@ middleware, and the async path inside ``TitleMiddleware``. Any new in-graph
``create_chat_model`` call must add to this list and pass the flag. ``create_chat_model`` call must add to this list and pass the flag.
""" """
from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING
from langchain.agents import create_agent from langchain.agents import create_agent
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
@@ -45,6 +48,11 @@ from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill from deerflow.skills.types import Skill
from deerflow.tracing import build_tracing_callbacks from deerflow.tracing import build_tracing_callbacks
if TYPE_CHECKING:
from langchain.tools import BaseTool
from deerflow.tools.builtins.tool_search import DeferredToolSetup
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -356,7 +364,7 @@ def _build_middlewares(
return middlewares return middlewares
def _assemble_deferred(filtered_tools, *, enabled: bool): def _assemble_deferred(filtered_tools: list[BaseTool], *, enabled: bool) -> tuple[list[BaseTool], DeferredToolSetup]:
"""Build the final tool list + deferred setup from a policy-filtered list. """Build the final tool list + deferred setup from a policy-filtered list.
Call AFTER tool-policy filtering so the deferred catalog never exposes a Call AFTER tool-policy filtering so the deferred catalog never exposes a
@@ -364,13 +372,16 @@ def _assemble_deferred(filtered_tools, *, enabled: bool):
and MCP tools survived filtering but no deferred set was recovered, raise and MCP tools survived filtering but no deferred set was recovered, raise
rather than silently binding their full schemas to the model. rather than silently binding their full schemas to the model.
""" """
from deerflow.tools.builtins.tool_search import _is_mcp_tool, build_deferred_tool_setup from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
from deerflow.tools.mcp_metadata import is_mcp_tool
setup = build_deferred_tool_setup(filtered_tools, enabled=enabled) deferred_setup = build_deferred_tool_setup(filtered_tools, enabled=enabled)
if enabled and not setup.deferred_names and any(_is_mcp_tool(t) for t in filtered_tools): if enabled and not deferred_setup.deferred_names and any(is_mcp_tool(t) for t in filtered_tools):
raise RuntimeError("tool_search enabled and MCP tools survived policy filtering, but no deferred set was recovered — refusing to bind MCP schemas (fail-closed).") raise RuntimeError("tool_search enabled and MCP tools survived policy filtering, but no deferred set was recovered — refusing to bind MCP schemas (fail-closed).")
final_tools = list(filtered_tools) + ([setup.tool_search_tool] if setup.tool_search_tool else []) final_tools = list(filtered_tools)
return final_tools, setup if deferred_setup.tool_search_tool:
final_tools.append(deferred_setup.tool_search_tool)
return final_tools, deferred_setup
def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None: def _available_skill_names(agent_config, is_bootstrap: bool) -> set[str] | None:
@@ -28,11 +28,25 @@ from langchain_core.tools import InjectedToolCallId, tool
from langchain_core.utils.function_calling import convert_to_openai_function from langchain_core.utils.function_calling import convert_to_openai_function
from langgraph.types import Command from langgraph.types import Command
from deerflow.tools.mcp_metadata import is_mcp_tool
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MAX_RESULTS = 5 # Max tools returned per search MAX_RESULTS = 5 # Max tools returned per search
def _compile_catalog_regex(pattern: str) -> re.Pattern[str]:
"""Compile ``pattern`` case-insensitively, falling back to a literal match.
Search queries come from the model, so an invalid regex (e.g. an unbalanced
paren) must degrade to a literal substring match rather than raise.
"""
try:
return re.compile(pattern, re.IGNORECASE)
except re.error:
return re.compile(re.escape(pattern), re.IGNORECASE)
# ── Catalog ── # ── Catalog ──
@@ -56,22 +70,25 @@ class DeferredToolCatalog:
return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16] return hashlib.sha256(blob.encode("utf-8")).hexdigest()[:16]
def search(self, query: str) -> list[BaseTool]: def search(self, query: str) -> list[BaseTool]:
query = query.strip()
if not query:
return []
if query.startswith("select:"): if query.startswith("select:"):
wanted = {n.strip() for n in query[7:].split(",")} wanted = {n.strip() for n in query[7:].split(",")}
return [t for t in self.tools if t.name in wanted][:MAX_RESULTS] return [t for t in self.tools if t.name in wanted][:MAX_RESULTS]
if query.startswith("+"): if query.startswith("+"):
parts = query[1:].split(None, 1) parts = query[1:].split(None, 1)
if not parts:
return [] # bare "+" with no required token — nothing to require
required = parts[0].lower() required = parts[0].lower()
candidates = [t for t in self.tools if required in t.name.lower()] candidates = [t for t in self.tools if required in t.name.lower()]
if len(parts) > 1: if len(parts) > 1:
candidates.sort(key=lambda t: _catalog_regex_score(parts[1], t), reverse=True) candidates.sort(key=lambda t: _catalog_regex_score(parts[1], t), reverse=True)
return candidates[:MAX_RESULTS] return candidates[:MAX_RESULTS]
try: regex = _compile_catalog_regex(query)
regex = re.compile(query, re.IGNORECASE)
except re.error:
regex = re.compile(re.escape(query), re.IGNORECASE)
scored: list[tuple[int, BaseTool]] = [] scored: list[tuple[int, BaseTool]] = []
for t in self.tools: for t in self.tools:
searchable = f"{t.name} {t.description or ''}" searchable = f"{t.name} {t.description or ''}"
@@ -82,10 +99,7 @@ class DeferredToolCatalog:
def _catalog_regex_score(pattern: str, t: BaseTool) -> int: def _catalog_regex_score(pattern: str, t: BaseTool) -> int:
try: regex = _compile_catalog_regex(pattern)
regex = re.compile(pattern, re.IGNORECASE)
except re.error:
regex = re.compile(re.escape(pattern), re.IGNORECASE)
return len(regex.findall(f"{t.name} {t.description or ''}")) return len(regex.findall(f"{t.name} {t.description or ''}"))
@@ -94,15 +108,25 @@ def _catalog_regex_score(pattern: str, t: BaseTool) -> int:
@dataclass(frozen=True) @dataclass(frozen=True)
class DeferredToolSetup: class DeferredToolSetup:
"""Result of assembling deferred-tool support for one agent build.
The three fields move as a unit, so callers branch on ``tool_search_tool``:
- **Empty** ``(None, frozenset(), None)``: deferral is disabled, or no MCP
tool survived policy filtering. Nothing is deferred — bind tools as-is.
- **Populated**: ``tool_search_tool`` is appended to the agent's tools,
``deferred_names`` are withheld from the model until promoted, and
``catalog_hash`` scopes those promotions in graph state.
Invariant: ``tool_search_tool is None`` ⟺ ``deferred_names`` is empty ⟺
``catalog_hash is None``.
"""
tool_search_tool: BaseTool | None tool_search_tool: BaseTool | None
deferred_names: frozenset[str] deferred_names: frozenset[str]
catalog_hash: str | None catalog_hash: str | None
def _is_mcp_tool(t: BaseTool) -> bool:
return (getattr(t, "metadata", None) or {}).get("deerflow_mcp") is True
def build_tool_search_tool(catalog: DeferredToolCatalog) -> BaseTool: def build_tool_search_tool(catalog: DeferredToolCatalog) -> BaseTool:
catalog_hash = catalog.hash catalog_hash = catalog.hash
@@ -141,11 +165,17 @@ def build_deferred_tool_setup(filtered_tools: list[BaseTool], *, enabled: bool)
Must be called after skill/agent tool-policy filtering so the catalog never Must be called after skill/agent tool-policy filtering so the catalog never
exposes a tool the current agent is not allowed to use. exposes a tool the current agent is not allowed to use.
Returns an empty setup (see :class:`DeferredToolSetup`) in two distinct
cases: deferral is disabled, or it is enabled but no MCP tool survived
filtering.
""" """
if not enabled: if not enabled:
# Deferral disabled: defer nothing; the model binds every tool as before.
return DeferredToolSetup(None, frozenset(), None) return DeferredToolSetup(None, frozenset(), None)
deferred = [t for t in filtered_tools if _is_mcp_tool(t)] deferred = [t for t in filtered_tools if is_mcp_tool(t)]
if not deferred: if not deferred:
# Enabled, but no MCP tool to defer: same empty result, different reason.
return DeferredToolSetup(None, frozenset(), None) return DeferredToolSetup(None, frozenset(), None)
catalog = DeferredToolCatalog(tuple(deferred)) catalog = DeferredToolCatalog(tuple(deferred))
return DeferredToolSetup(build_tool_search_tool(catalog), catalog.names, catalog.hash) return DeferredToolSetup(build_tool_search_tool(catalog), catalog.names, catalog.hash)
@@ -0,0 +1,29 @@
"""Single source of truth for the MCP-tool metadata tag.
A tool is "MCP-sourced" when it carries the ``deerflow_mcp`` metadata flag.
The tag is *written* where MCP tools are loaded (``tools.py``) and *read* by
deferred-tool assembly (``tool_search.py``) and the agent build site
(``agent.py``). Keeping the key, the tagger, and the predicate here means the
magic string lives in exactly one place, and readers import a public predicate
instead of a private cross-module helper.
This is a leaf module by design: it depends only on ``BaseTool`` so that any
module (including the tool loader) can import it without an import cycle.
"""
from __future__ import annotations
from langchain.tools import BaseTool
MCP_TOOL_METADATA_KEY = "deerflow_mcp"
def tag_mcp_tool(tool: BaseTool) -> BaseTool:
"""Mark ``tool`` as MCP-sourced. Mutates in place and returns it for chaining."""
tool.metadata = {**(tool.metadata or {}), MCP_TOOL_METADATA_KEY: True}
return tool
def is_mcp_tool(tool: BaseTool) -> bool:
"""True when ``tool`` carries the MCP-source tag written by :func:`tag_mcp_tool`."""
return (getattr(tool, "metadata", None) or {}).get(MCP_TOOL_METADATA_KEY) is True
@@ -7,6 +7,7 @@ from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.mcp_metadata import tag_mcp_tool
from deerflow.tools.sync import make_sync_tool_wrapper from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -132,7 +133,7 @@ def get_available_tools(
# the deferred catalog + tool_search tool are assembled per # the deferred catalog + tool_search tool are assembled per
# agent from the policy-filtered tool list. # agent from the policy-filtered tool list.
for t in mcp_tools: for t in mcp_tools:
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True} tag_mcp_tool(t)
except ImportError: except ImportError:
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.") logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
except Exception as e: except Exception as e:
+17
View File
@@ -54,6 +54,23 @@ def test_search_invalid_regex_falls_back_to_literal():
assert cat.search("zzz(") == [] assert cat.search("zzz(") == []
def test_search_empty_query_returns_empty(catalog):
# An empty / whitespace-only query is meaningless; rather than let the empty
# regex match every tool, search() returns nothing so the model gets a clear
# "no match" signal and re-queries instead of acting on noise.
assert catalog.search("") == []
assert catalog.search(" ") == []
def test_search_bare_plus_returns_empty(catalog):
# A "+" prefix with no required token is malformed model input. It must
# return no matches, not raise IndexError on parts[0]. " + " strips to "+",
# so it routes here too and must be handled the same way.
assert catalog.search("+") == []
assert catalog.search(" + ") == []
assert catalog.search("+ ") == []
def test_hash_stable_across_instances(): def test_hash_stable_across_instances():
c1 = DeferredToolCatalog((alpha_search, beta_translate)) c1 = DeferredToolCatalog((alpha_search, beta_translate))
c2 = DeferredToolCatalog((beta_translate, alpha_search)) c2 = DeferredToolCatalog((beta_translate, alpha_search))
@@ -20,6 +20,7 @@ from langchain_core.tools import tool as as_tool
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.agents.thread_state import ThreadState from deerflow.agents.thread_state import ThreadState
from deerflow.tools.builtins.tool_search import build_deferred_tool_setup from deerflow.tools.builtins.tool_search import build_deferred_tool_setup
from deerflow.tools.mcp_metadata import tag_mcp_tool
@as_tool @as_tool
@@ -40,11 +41,6 @@ def mcp_other(x: str) -> str:
return x return x
def _tag(t):
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
return t
def test_tool_search_promotes_into_next_turn(): def test_tool_search_promotes_into_next_turn():
bound: list[list[str]] = [] bound: list[list[str]] = []
@@ -53,7 +49,7 @@ def test_tool_search_promotes_into_next_turn():
bound.append([getattr(t, "name", None) for t in tools]) bound.append([getattr(t, "name", None) for t in tools])
return self return self
setup = build_deferred_tool_setup([active_tool, _tag(mcp_calc), _tag(mcp_other)], enabled=True) setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_calc), tag_mcp_tool(mcp_other)], enabled=True)
turn1 = AIMessage(content="", tool_calls=[{"name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "c1", "type": "tool_call"}]) turn1 = AIMessage(content="", tool_calls=[{"name": "tool_search", "args": {"query": "select:mcp_calc"}, "id": "c1", "type": "tool_call"}])
turn2 = AIMessage(content="done") turn2 = AIMessage(content="done")
model = RecordingModel(messages=iter([turn1, turn2])) model = RecordingModel(messages=iter([turn1, turn2]))
+6 -10
View File
@@ -1,7 +1,8 @@
from langchain_core.tools import tool as as_tool from langchain_core.tools import tool as as_tool
from langgraph.types import Command from langgraph.types import Command
from deerflow.tools.builtins.tool_search import DeferredToolCatalog, _is_mcp_tool, build_deferred_tool_setup, build_tool_search_tool from deerflow.tools.builtins.tool_search import DeferredToolCatalog, build_deferred_tool_setup, build_tool_search_tool
from deerflow.tools.mcp_metadata import is_mcp_tool, tag_mcp_tool
@as_tool @as_tool
@@ -16,18 +17,13 @@ def local_echo(text: str) -> str:
return text return text
def _tag_mcp(t):
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
return t
def test_is_mcp_tool_reads_metadata(): def test_is_mcp_tool_reads_metadata():
assert _is_mcp_tool(_tag_mcp(mcp_calc)) is True assert is_mcp_tool(tag_mcp_tool(mcp_calc)) is True
assert _is_mcp_tool(local_echo) is False assert is_mcp_tool(local_echo) is False
def test_setup_disabled_returns_empty(): def test_setup_disabled_returns_empty():
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=False) setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=False)
assert setup.tool_search_tool is None assert setup.tool_search_tool is None
assert setup.deferred_names == frozenset() assert setup.deferred_names == frozenset()
assert setup.catalog_hash is None assert setup.catalog_hash is None
@@ -40,7 +36,7 @@ def test_setup_no_mcp_returns_empty():
def test_setup_builds_from_mcp_survivors(): def test_setup_builds_from_mcp_survivors():
setup = build_deferred_tool_setup([_tag_mcp(mcp_calc), local_echo], enabled=True) setup = build_deferred_tool_setup([tag_mcp_tool(mcp_calc), local_echo], enabled=True)
assert setup.deferred_names == frozenset({"mcp_calc"}) assert setup.deferred_names == frozenset({"mcp_calc"})
assert setup.tool_search_tool is not None assert setup.tool_search_tool is not None
assert setup.tool_search_tool.name == "tool_search" assert setup.tool_search_tool.name == "tool_search"
@@ -23,6 +23,7 @@ from deerflow.agents.middlewares.deferred_tool_filter_middleware import Deferred
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
from deerflow.skills.types import Skill from deerflow.skills.types import Skill
from deerflow.tools.builtins.tool_search import DeferredToolSetup, build_deferred_tool_setup from deerflow.tools.builtins.tool_search import DeferredToolSetup, build_deferred_tool_setup
from deerflow.tools.mcp_metadata import tag_mcp_tool
@as_tool @as_tool
@@ -37,11 +38,6 @@ def mcp_secret(x: str) -> str:
return x return x
def _tag(t):
t.metadata = {**(t.metadata or {}), "deerflow_mcp": True}
return t
_BOUND: list[list[str]] = [] _BOUND: list[list[str]] = []
@@ -52,7 +48,7 @@ class _RecordingModel(GenericFakeChatModel):
def _build_graph(): def _build_graph():
filtered = [active_tool, _tag(mcp_secret)] filtered = [active_tool, tag_mcp_tool(mcp_secret)]
setup = build_deferred_tool_setup(filtered, enabled=True) setup = build_deferred_tool_setup(filtered, enabled=True)
final = [*filtered, setup.tool_search_tool] final = [*filtered, setup.tool_search_tool]
model = _RecordingModel(messages=iter([AIMessage(content="done")] * 4)) model = _RecordingModel(messages=iter([AIMessage(content="done")] * 4))
@@ -107,18 +103,18 @@ def test_fail_closed_when_mcp_survives_without_setup(monkeypatch):
lambda tools, *, enabled: DeferredToolSetup(None, frozenset(), None), lambda tools, *, enabled: DeferredToolSetup(None, frozenset(), None),
) )
with pytest.raises(RuntimeError, match="fail-closed"): with pytest.raises(RuntimeError, match="fail-closed"):
agentmod._assemble_deferred([_tag(mcp_secret)], enabled=True) agentmod._assemble_deferred([tag_mcp_tool(mcp_secret)], enabled=True)
def test_subagent_reentry_does_not_touch_lead_state(): def test_subagent_reentry_does_not_touch_lead_state():
"""#2884: building a second (subagent) setup must not affect the lead's """#2884: building a second (subagent) setup must not affect the lead's
middleware. With no shared registry/ContextVar, the lead middleware depends middleware. With no shared registry/ContextVar, the lead middleware depends
only on its own deferred_names + the passed state.""" only on its own deferred_names + the passed state."""
lead_setup = build_deferred_tool_setup([active_tool, _tag(mcp_secret)], enabled=True) lead_setup = build_deferred_tool_setup([active_tool, tag_mcp_tool(mcp_secret)], enabled=True)
mw = DeferredToolFilterMiddleware(lead_setup.deferred_names, lead_setup.catalog_hash) mw = DeferredToolFilterMiddleware(lead_setup.deferred_names, lead_setup.catalog_hash)
# Simulate a subagent build re-entering tool assembly with its own setup. # Simulate a subagent build re-entering tool assembly with its own setup.
_ = build_deferred_tool_setup([_tag(mcp_secret)], enabled=True) _ = build_deferred_tool_setup([tag_mcp_tool(mcp_secret)], enabled=True)
class _Req: class _Req:
def __init__(self): def __init__(self):
@@ -154,7 +150,7 @@ def test_policy_denied_mcp_yields_no_tool_search_end_to_end():
tool_search (and does not fail-closed, because no MCP tool leaked through).""" tool_search (and does not fail-closed, because no MCP tool leaked through)."""
from deerflow.agents.lead_agent import agent as agentmod from deerflow.agents.lead_agent import agent as agentmod
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(["active_tool"])]) filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(["active_tool"])])
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True) final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
assert [t.name for t in final_tools] == ["active_tool"] assert [t.name for t in final_tools] == ["active_tool"]
@@ -174,7 +170,7 @@ def test_tool_search_appended_after_policy_but_never_exposes_denied_tool():
from deerflow.agents.lead_agent import agent as agentmod from deerflow.agents.lead_agent import agent as agentmod
allowed = ["active_tool", "mcp_secret"] # permits the MCP tool, does NOT list tool_search allowed = ["active_tool", "mcp_secret"] # permits the MCP tool, does NOT list tool_search
filtered = filter_tools_by_skill_allowed_tools([active_tool, _tag(mcp_secret)], [_make_skill(allowed)]) filtered = filter_tools_by_skill_allowed_tools([active_tool, tag_mcp_tool(mcp_secret)], [_make_skill(allowed)])
final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True) final_tools, setup = agentmod._assemble_deferred(filtered, enabled=True)
names = {t.name for t in final_tools} names = {t.name for t in final_tools}