PyTorch Convolutional Network Regularization: DropBlock
2022-07-03 02:11:00 【Hebi tongzj】
Paper: https://arxiv.org/pdf/1810.12890.pdf
Paper Summary
DropBlock is a simple method similar to dropout. Its main difference from dropout is that it erases contiguous regions of a layer's feature map instead of independent random units.
Like dropout, DropBlock randomly zeros part of the network's responses, which decouples the channels from one another and alleviates overfitting.
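To make the "contiguous region vs. independent units" difference concrete, here is a minimal sketch on a toy 7×7 map. It assumes a single erased center placed by hand at (3, 3) and uses `max_pool2d` with stride 1 to grow that center into a square, which is the same trick the reimplementation later in this post relies on; the dropout-style mask is just an independent Bernoulli draw for comparison.

```python
import torch
import torch.nn.functional as F

# One hand-picked erased center at (3, 3) on a 7x7 map (batch=1, channel=1)
centers = torch.zeros(1, 1, 7, 7)
centers[0, 0, 3, 3] = 1.0

block_size = 3  # odd, so padding=block_size // 2 keeps the spatial size unchanged
# Stride-1 max pooling spreads each center over a block_size x block_size square
drop_region = F.max_pool2d(centers, kernel_size=block_size,
                           stride=1, padding=block_size // 2)
dropblock_keep = 1 - drop_region  # 0 inside the square, 1 elsewhere

# Dropout-style mask for comparison: every unit is kept or dropped independently
torch.manual_seed(0)
dropout_keep = torch.bernoulli(torch.full((1, 1, 7, 7), 0.9))

print(dropblock_keep[0, 0])  # a contiguous 3x3 hole around (3, 3)
print(dropout_keep[0, 0])    # isolated zeros scattered over the map
```

Because the dropped units of the block mask are spatially adjacent, nearby activations cannot simply "fill in" the missing information, which is the intuition the paper gives for why DropBlock regularizes convolutional features better than dropout.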
The pseudocode of the algorithm, written here as runnable PyTorch, is as follows:
- x: feature map, with shape [bs, ch, h, w]
- block_size: side length of the contiguous region to erase
- gamma (γ): mean of the Bernoulli distribution, used to select the center points of the erased regions
- training: boolean, True in train mode and False in eval mode

```python
import torch
import torch.nn.functional as F


def drop_block(x, block_size, gamma, training):
    if training:
        # Select the center points of the regions to erase
        del_mask = torch.bernoulli(torch.full_like(x, gamma))
        # Erase the corresponding regions: grow each center into a
        # block_size x block_size square (block_size must be odd)
        del_mask = F.max_pool2d(del_mask, kernel_size=block_size,
                                stride=1, padding=block_size // 2)
        keep_mask = 1 - del_mask
        x = x * keep_mask
        # Feature-map normalization: total unit count / retained unit count
        x = x * keep_mask.numel() / keep_mask.sum().clamp(min=1)
        return x
    # No-op in eval mode
    return x
```

In a concrete implementation, however, several more details have to be filled in.

γ is determined by the parameter keep_prob, the probability that an active unit (one whose output is greater than 0) is kept. With feat_size denoting the side length of the feature map, the paper gives:

$$\gamma = \frac{1 - keep\_prob}{block\_size^2} \cdot \frac{feat\_size^2}{(feat\_size - block\_size + 1)^2}$$
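A minimal sketch of the γ computation, generalized to an h × w map by dividing per axis (the same way the `DropBlock` class further down computes `self.gamma`). The helper name `compute_gamma` is hypothetical. The first factor spreads the drop rate over a whole block; the second corrects for the fact that only positions where a full block fits count as valid centers.

```python
def compute_gamma(keep_prob, block_size, height, width=None):
    """Bernoulli mean for DropBlock centers on a height x width map.

    For a square map (height == width == feat_size) this equals
        (1 - keep_prob) / block_size**2 * feat_size**2 / (feat_size - block_size + 1)**2
    """
    if width is None:
        width = height
    if block_size > min(height, width):
        raise ValueError("block_size must not exceed the feature-map size")
    gamma = (1 - keep_prob) * height * width / block_size ** 2
    # (h - block_size + 1) * (w - block_size + 1) = number of positions where a full block fits
    gamma /= (height - block_size + 1) * (width - block_size + 1)
    return gamma


print(compute_gamma(0.9, 1, 14))  # block_size = 1 reduces to dropout's drop rate (~0.1)
print(compute_gamma(0.9, 7, 14))  # ~0.00625
```

With block_size = 1 the formula collapses to 1 - keep_prob, i.e. ordinary dropout, which is a quick sanity check on the derivation.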

Because a small keep_prob at the start of training hurts convergence, keep_prob is gradually decreased from 1.0 to 0.9.
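Two ways to schedule this decrease are sketched below, with hypothetical function names. The paper describes a linear decrease over training; the reimplementation later in this post instead multiplies keep_prob by (1 - decay) on each update and clamps it at the target. Both are shown so the number of steps each one needs can be compared.

```python
def linear_keep_prob(step, total_steps, start=1.0, target=0.9):
    """Linearly interpolate keep_prob from `start` to `target` over `total_steps`."""
    frac = min(step / total_steps, 1.0)
    return start + (target - start) * frac


def geometric_keep_prob(n_updates, start=1.0, target=0.9, decay=1e-2):
    """keep_prob after `n_updates` multiplicative decays, floored at `target`."""
    return max(target, start * (1 - decay) ** n_updates)


if __name__ == "__main__":
    for step in (0, 500, 1000, 2000):
        print(step, round(linear_keep_prob(step, 1000), 4))
    # Count how many updates the geometric schedule needs to reach the floor
    n = 0
    while geometric_keep_prob(n) > 0.9:
        n += 1
    print("geometric schedule reaches 0.9 after", n, "updates")
```

With decay = 1e-2 the geometric schedule hits 0.9 after 11 updates (0.99^10 ≈ 0.904, 0.99^11 ≈ 0.895), so in the class below, where an update happens on every eval-mode forward pass, how fast keep_prob drops depends on how often eval mode is entered.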
The experimental results show that ResNet-50 reaches higher validation accuracy after DropBlock is added.

The following results show how different DropBlock placements, application schemes, and block_size values affect validation accuracy:
- By row: DropBlock added after the 4th group of ResNet-50; DropBlock added after the 3rd and 4th groups of ResNet-50
- By column: added only to the convolution branch; added to both the convolution branch and the skip connection; added to both, with keep_prob decay

In the paper, the best hyperparameters are block_size = 7 and keep_prob = 0.9, but they still need to be tuned according to how the loss evolves.
Reimplementing DropBlock
When implementing DropBlock, the following details matter:
- keep_prob is dynamic: it is updated every time the module runs in eval mode
- The centers of the erased regions are chosen among the active units (outputs greater than 0); positions sampled as 1 become centers, and max_pool2d expands them into contiguous regions to produce del_mask
- The normalization coefficient is the area of the original map divided by the retained area, but computing the retained area exactly costs extra computation and slows down training, so it is approximated by 1 / keep_prob
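To see how close the 1 / keep_prob approximation is, here is a small sketch on a toy 8×8 map where every unit is active and one hand-placed center is grown into a 3×3 erased block. The exact coefficient is computed per (sample, channel), like the commented-out `gain` line in the class below; keep_prob = 0.9 here is only an illustrative value, not one derived from this mask.

```python
import torch
import torch.nn.functional as F

block_size = 3
x = torch.ones(1, 1, 8, 8)  # toy feature map where every unit is active

# One erased center at (3, 3); max pooling turns it into a 3x3 erased block
centers = torch.zeros_like(x)
centers[0, 0, 3, 3] = 1.0
keep_mask = 1 - F.max_pool2d(centers, kernel_size=block_size,
                             stride=1, padding=block_size // 2)

# Exact normalization: total area / retained area, per (sample, channel)
area = x.shape[-2] * x.shape[-1]
kept = keep_mask.sum(dim=(-2, -1), keepdim=True)
exact_gain = area / kept.clamp(min=1)  # clamp guards against a fully erased map

# Cheap approximation: 1 / keep_prob (illustrative keep_prob)
keep_prob = 0.9
approx_gain = 1 / keep_prob

out_exact = keep_mask * x * exact_gain
out_approx = keep_mask * x * approx_gain
print(exact_gain.item(), approx_gain)                   # ~1.1636 vs ~1.1111
print(out_exact.mean().item(), out_approx.mean().item())  # ~1.0 vs ~0.955
```

The exact coefficient restores the mean activation precisely, while 1 / keep_prob only matches it on average over many random masks; the trade-off is one extra reduction per forward pass.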
```python
import torch
import torch.nn as nn


class DropBlock(nn.Module):
    ''' block_size: size of the erased region (must be odd)
        keep_prob_init: initial value of keep_prob
        keep_prob_tar: target value of keep_prob
        keep_prob_decay: decay rate of keep_prob '''

    def __init__(self, block_size=5, keep_prob_init=1.,
                 keep_prob_tar=0.9, keep_prob_decay=1e-2):
        super().__init__()
        self.block_size = block_size
        assert self.block_size & 1, 'block_size must be odd'
        # keep_prob-related parameters
        self.keep_prob = keep_prob_init
        self._keep_prob_tar = keep_prob_tar
        self._keep_prob_decay = keep_prob_decay
        # Mean of the Bernoulli distribution (computed lazily in forward)
        self.gamma = None

    def forward(self, x):
        # Training mode
        if self.training:
            *bs_ch, height, width = x.shape
            square = height * width
            # Compute γ when it has not been set yet
            if self.gamma is None:
                self.gamma = (1 - self.keep_prob) * square / self.block_size ** 2
                for f_size in (height, width):
                    self.gamma /= f_size - self.block_size + 1
            # Select the erased-region centers among the active units (x > 0)
            del_mask = torch.bernoulli((x > 0) * self.gamma)
            # Grow each center into a block_size x block_size square
            keep_mask = 1 - torch.max_pool2d(
                del_mask, kernel_size=self.block_size,
                stride=1, padding=self.block_size // 2
            )
            # Feature-map normalization, approximated by 1 / keep_prob
            # gain = square / keep_mask.view(*bs_ch, -1).sum(2).view(*bs_ch, 1, 1)
            return keep_mask * x / self.keep_prob
        # Eval mode: decay keep_prob and reset γ so it is recomputed in training
        self.keep_prob = max(self._keep_prob_tar,
                             self.keep_prob * (1 - self._keep_prob_decay))
        self.gamma = None
        return x
```

Code testing
```python
import cv2 as cv
import torchvision.transforms as tf

# Use the grayscale image to set low-brightness pixels to 0
image = cv.imread('YouXiZi.jpg')
mask = cv.cvtColor(image, cv.COLOR_BGR2GRAY) > 100
for i in range(3):
    image[..., i] *= mask
cv.imshow('debug', image)
cv.waitKey(0)
# Convert to a tensor and apply DropBlock
tensor = tf.ToTensor()(image)
db = DropBlock(block_size=31, keep_prob_init=0.9)
image = db(tensor.unsqueeze(0))[0]
image = image.permute(1, 2, 0).data.numpy()
cv.imshow('debug', image)
cv.waitKey(0)
```

The grayscale image is used to set dark pixels to zero, so the bright regions are the active units.

The centers of the erased regions fall in the bright areas, and the remaining pixels are brighter than in the original image (the normalization coefficient is greater than 1).
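For readers without the image file or a display, the same two observations can be checked headlessly. This is a minimal sketch, assuming a synthetic tensor (left half bright, right half dark) in place of YouXiZi.jpg and a fixed γ instead of the keep_prob-derived one; `drop_block_active` is a hypothetical functional restatement of the class's training-mode forward pass, not part of the original code.

```python
import torch
import torch.nn.functional as F


def drop_block_active(x, block_size, gamma, keep_prob):
    """Hypothetical functional version of the training-mode forward pass:
    centers are sampled only where x > 0, grown with max pooling,
    and the result is rescaled by 1 / keep_prob."""
    del_mask = torch.bernoulli((x > 0).float() * gamma)
    keep_mask = 1 - F.max_pool2d(del_mask, kernel_size=block_size,
                                 stride=1, padding=block_size // 2)
    return keep_mask * x / keep_prob


torch.manual_seed(0)
# Synthetic "image": left half bright (active), right half dark (zero)
x = torch.zeros(1, 3, 64, 64)
x[..., :32] = 0.5

block_size, keep_prob = 7, 0.9
out = drop_block_active(x, block_size, gamma=0.02, keep_prob=keep_prob)

# Active units that were erased
erased = (x > 0) & (out == 0)
print("erased active units:", int(erased.sum()))
# Surviving active units are scaled by 1 / keep_prob > 1, i.e. brighter
survivors = out[(x > 0) & (out != 0)]
print("survivor values:", survivors.unique())
```

The dark half stays zero (no centers can be sampled there), erased blocks appear only in the bright half, and every surviving bright pixel rises from 0.5 to about 0.556.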
