fix: gate deferred MCP tool execution (#2513)

* fix: gate deferred MCP tool execution

* style: format deferred tool middleware

* fix: address deferred tool review feedback
This commit is contained in:
DanielWalnut
2026-04-24 22:45:41 +08:00
committed by GitHub
parent d78ed5c8f2
commit ec8a8cae38
3 changed files with 155 additions and 1 deletions
@@ -16,6 +16,9 @@ from typing import override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse
from langchain_core.messages import ToolMessage
from langgraph.prebuilt.tool_node import ToolCallRequest
from langgraph.types import Command
logger = logging.getLogger(__name__)
@@ -35,7 +38,7 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
if not registry:
return request
deferred_names = {e.name for e in registry.entries}
deferred_names = registry.deferred_names
active_tools = [t for t in request.tools if getattr(t, "name", None) not in deferred_names]
if len(active_tools) < len(request.tools):
@@ -43,6 +46,28 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
return request.override(tools=active_tools)
def _blocked_tool_message(self, request: ToolCallRequest) -> ToolMessage | None:
from deerflow.tools.builtins.tool_search import get_deferred_registry
registry = get_deferred_registry()
if not registry:
return None
tool_name = str(request.tool_call.get("name") or "")
if not tool_name:
return None
if not registry.contains(tool_name):
return None
tool_call_id = str(request.tool_call.get("id") or "missing_tool_call_id")
return ToolMessage(
content=(f"Error: Tool '{tool_name}' is deferred and has not been promoted yet. Call tool_search first to expose and promote this tool's schema, then retry."),
tool_call_id=tool_call_id,
name=tool_name,
status="error",
)
@override
def wrap_model_call(
self,
@@ -51,6 +76,17 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
) -> ModelCallResult:
return handler(self._filter_tools(request))
@override
def wrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], ToolMessage | Command],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return handler(request)
@override
async def awrap_model_call(
self,
@@ -58,3 +94,14 @@ class DeferredToolFilterMiddleware(AgentMiddleware[AgentState]):
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._filter_tools(request))
@override
async def awrap_tool_call(
self,
request: ToolCallRequest,
handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]],
) -> ToolMessage | Command:
blocked = self._blocked_tool_message(request)
if blocked is not None:
return blocked
return await handler(request)
@@ -112,6 +112,15 @@ class DeferredToolRegistry:
def entries(self) -> list[DeferredToolEntry]:
return list(self._entries)
@property
def deferred_names(self) -> set[str]:
"""Names of tools that are still hidden from model binding."""
return {entry.name for entry in self._entries}
def contains(self, name: str) -> bool:
"""Return whether *name* is still deferred."""
return any(entry.name == name for entry in self._entries)
def __len__(self) -> int:
return len(self._entries)