
Global pooling – PyTorch

2022-06-11 10:22:00 liiiiiiiiiiiiike

Global average pooling

Convolutional neural networks can handle both regression and classification problems, but a typical CNN performs classification through fully connected layers. This causes a surge in the number of neurons and in the amount of computation, and for networks with strict regression requirements it can introduce unwanted side effects. People have therefore come up with alternatives to the fully connected layer; the two most common replace the final flatten with global max / average pooling. The two approaches are compared below:
[Figure: comparison of flatten vs. global pooling]
As you can see, global pooling produces neurons on demand: the number of output neurons is controllable and adjustable. Flatten, by contrast, is a hard-wired connection whose number of links cannot be adjusted. The most common form of global average pooling outputs one neuron per channel's feature map (the mean of that map), as shown below:
[Figure: global average pooling, one output per channel's feature map]
Global max pooling, illustrated below, instead takes the maximum of each feature map:
[Figure: global max pooling, the maximum of each feature map]
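In code, these two variants are just mean and max reductions over the spatial dimensions. A minimal sketch for illustration (the tensor x and its shape here are assumptions, not part of the original demo):

 import torch

 x = torch.randn(2, 8, 4, 4)                 # N x C x H x W
 gap = x.mean(dim=(2, 3), keepdim=True)      # global average pooling -> 2 x 8 x 1 x 1
 gmp = x.amax(dim=(2, 3), keepdim=True)      # global max pooling -> 2 x 8 x 1 x 1
 print(gap.size(), gmp.size())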
The input to global average pooling and global max pooling is generally N x C x H x W, and the output is N x C x 1 x 1. Sometimes, however, there is another need: global depthwise pooling, whose output is N x 1 x H x W. This kind of pooling usually reshapes the data to N x H*W x C, applies one-dimensional max / average pooling over the C dimension, and finally reshapes the result back to N x 1 x H x W. Having covered these common global pooling variants, let's look at the related functions PyTorch supports.
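Before the class-based demo below, here is a minimal functional sketch of that reshape-then-pool procedure (variable names are illustrative):

 import torch
 import torch.nn.functional as F

 x = torch.randn(1, 8, 4, 4)                      # N x C x H x W
 n, c, h, w = x.size()
 flat = x.view(n, c, h * w).permute(0, 2, 1)      # N x H*W x C
 pooled = F.max_pool1d(flat, kernel_size=c)       # pool across channels -> N x H*W x 1
 out = pooled.permute(0, 2, 1).view(n, 1, h, w)   # N x 1 x H x W
 print(out.size())                                # torch.Size([1, 1, 4, 4])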

PyTorch global pooling code demo

  • Global maximum pooling
torch.nn.AdaptiveMaxPool2d(output_size, return_indices=False)
  • Global average pooling
torch.nn.AdaptiveAvgPool2d(output_size)

Here output_size specifies the output H x W; for global pooling it is normally set to (1, 1). Example code:

 import torch

 # Input is N x C x H x W = 1 x 8 x 4 x 4
 input = torch.randn(1, 8, 4, 4)

 # Global average pooling
 avg_pooling = torch.nn.AdaptiveAvgPool2d((1, 1))
 B, C, H, W = input.size()
 output = avg_pooling(input).view(B, -1)
 print("Global average pooling:", output.size())
 print(output, "\n")

 # Global max pooling
 max_pooling = torch.nn.AdaptiveMaxPool2d((1, 1))
 output = max_pooling(input).view(B, -1)
 print("Global max pooling:", output.size())
 print(output, "\n")

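Note that output_size is not limited to (1, 1): adaptive pooling lets you pick the output resolution, which is exactly what makes the number of resulting neurons adjustable. A short continuation of the demo above (the (2, 2) size is an arbitrary illustration):

 # Coarser adaptive pooling: each channel is reduced to 2 x 2 instead of 1 x 1
 pool_2x2 = torch.nn.AdaptiveAvgPool2d((2, 2))
 print(pool_2x2(input).size())                 # torch.Size([1, 8, 2, 2])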

  • Global depthwise pooling: reduces the channel dimension to 1 while keeping the spatial size, i.e. N x C x H x W -> N x 1 x H x W. A sketch implementation:
 class DeepWise_Pool(torch.nn.MaxPool1d):
     """Global depthwise max pooling: N x C x H x W -> N x 1 x H x W."""
     def __init__(self, channels):
         # The 1D kernel spans all channels, so pooling reduces C to 1
         super(DeepWise_Pool, self).__init__(kernel_size=channels)

     def forward(self, input):
         n, c, h, w = input.size()
         # Move channels last: N x H*W x C
         input = input.view(n, c, h * w).permute(0, 2, 1)
         # 1D max pooling across the channel dimension
         pooled = torch.nn.functional.max_pool1d(
                         input, self.kernel_size, self.stride,
                         self.padding, self.dilation, self.ceil_mode,
                         self.return_indices)
         # Restore the spatial layout: N x 1 x H x W
         return pooled.permute(0, 2, 1).view(n, 1, h, w)


 input = torch.randn(1, 8, 4, 4)
 print("input data:\n", input)
 print("input data:", input.size())
 B, C, H, W = input.size()
 dw_max_pool = DeepWise_Pool(C)
 output = dw_max_pool(input)
 print("Global depthwise pooling:", output.size())
 print(output, "\n")
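As a quick sanity check (not part of the original post), the same depthwise result can be obtained with a one-line max reduction over the channel dimension, and the two should agree:

 # Equivalent one-liner: max over the channel dimension
 direct = input.amax(dim=1, keepdim=True)      # N x 1 x H x W
 print(torch.allclose(output, direct))         # expected: True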


Global pooling benefits:

In CNN image classification, convolution acts as the feature extractor, and an FC layer + softmax acts as the classifier. The weakness is that FC layers contain too many neurons and overfit easily; dropout is usually used to mitigate overfitting, but the model still has many parameters and runs slowly. Global average pooling (GAP) can avoid the FC layer: classification is done directly through global pooling + softmax, with a much lower parameter count. There is another advantage: global pooling also partially preserves spatial structure information from the input image.
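For example, a NIN-style head can replace the FC layer entirely: a 1x1 convolution maps the features to one map per class, and GAP reduces each class map to a single logit. A minimal sketch (the input channel count 8 and num_classes = 10 are illustrative assumptions):

 import torch

 num_classes = 10
 head = torch.nn.Sequential(
     torch.nn.Conv2d(8, num_classes, kernel_size=1),  # one class map per category
     torch.nn.AdaptiveAvgPool2d((1, 1)),              # global average pooling
     torch.nn.Flatten(),                              # N x num_classes logits
 )
 logits = head(torch.randn(1, 8, 4, 4))
 probs = torch.softmax(logits, dim=1)
 print(logits.size(), probs.sum())                    # torch.Size([1, 10]), ~1.0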

Copyright notice: this article was written by liiiiiiiiiiiiike; please include a link to the original when reposting: https://yzsam.com/2022/162/202206110917083523.html