PyTorch Regularization for Convolutional Networks: DropBlock
2022-07-03 01:51:00 [荷碧·TongZJ]
Paper: https://arxiv.org/pdf/1810.12890.pdf
Paper overview
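DropBlock is a simple method similar to dropout. The main difference is that DropBlock erases contiguous regions of a layer's feature map instead of independent random units.
Like dropout, DropBlock randomly zeroes the network's responses, which decouples the channels from one another and mitigates overfitting.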
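The pseudocode of the algorithm is as follows:
- x: the feature map, with shape [bs, ch, h, w]
- block_size: the size of each erased contiguous region
- γ: the mean of the Bernoulli distribution, used to select the centers of the regions to erase
- training: boolean indicating whether the model is in train or eval mode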
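```
def DropBlock(x, block_size, γ, training):
    if training:
        # select the centers of the regions to erase
        del_mask = bernoulli(x, γ)
        # grow each center into a block_size x block_size region
        del_mask = expand_to_blocks(del_mask, block_size)
        # erase the corresponding regions
        x = set_zero(x, del_mask)
        # rescale the feature map: count(M) / count_ones(M)
        keep_mask = 1 - del_mask
        x *= count(keep_mask) / count_ones(keep_mask)
        return x
    # eval mode: no-op
    return x
```

An actual implementation, however, needs quite a few more details.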
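As a first step, the pseudocode above can be made runnable. The sketch below is a direct PyTorch translation under a few assumptions of mine: γ is passed in directly instead of being derived from keep_prob, centers may land on any unit (not only active ones), block_size is odd, and count / count_ones are taken over the whole batch tensor. The function name drop_block is hypothetical.

```python
import torch
import torch.nn.functional as F


def drop_block(x: torch.Tensor, block_size: int, gamma: float, training: bool) -> torch.Tensor:
    """Direct translation of the pseudocode; x has shape [bs, ch, h, w] (hypothetical helper)."""
    # eval mode: identity
    if not training:
        return x
    # each position becomes a block center with probability gamma
    center_mask = torch.bernoulli(torch.full_like(x, gamma))
    # grow every center into a block_size x block_size square (odd block_size keeps h, w)
    block_mask = F.max_pool2d(center_mask, kernel_size=block_size,
                              stride=1, padding=block_size // 2)
    keep_mask = 1 - block_mask
    # exact rescaling count(M) / count_ones(M); clamp avoids dividing by zero
    gain = keep_mask.numel() / keep_mask.sum().clamp(min=1.0)
    return x * keep_mask * gain
```

Unlike the class later in this post, this version uses the exact normalization instead of 1 / keep_prob; the clamp only matters in the degenerate case where every unit is dropped.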

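γ is determined by the keep_prob parameter, where keep_prob is the probability that an active unit (i.e. one whose output is greater than 0) is kept, and feat_size is the size of the feature map:

$$\gamma = \frac{1 - \text{keep\_prob}}{\text{block\_size}^2} \cdot \frac{\text{feat\_size}^2}{(\text{feat\_size} - \text{block\_size} + 1)^2}$$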
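As a sketch, the helper below evaluates γ. The name compute_gamma is hypothetical; like the DropBlock class later in this post, it generalizes the square feat_size to an h × w map by dividing by (h − block_size + 1)(w − block_size + 1), the number of valid block positions.

```python
def compute_gamma(keep_prob: float, block_size: int, height: int, width: int) -> float:
    """gamma = (1 - keep_prob) / block_size**2 * (h * w) / ((h - bs + 1) * (w - bs + 1))."""
    if block_size > height or block_size > width:
        raise ValueError("block_size must not exceed the feature-map size")
    valid_h = height - block_size + 1  # valid block-center rows
    valid_w = width - block_size + 1   # valid block-center columns
    return (1 - keep_prob) / block_size ** 2 * (height * width) / (valid_h * valid_w)
```

For a 7 × 7 map (the last stage of ResNet-50 at 224 × 224 input) with block_size = 7 and keep_prob = 0.9, only one block position fits and γ works out to 0.1; on a 14 × 14 map it drops to 0.00625.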

Because a small keep_prob at the very beginning of training hurts convergence, keep_prob is gradually decreased from 1.0 to 0.9.
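The paper decreases keep_prob linearly during training. Below is a minimal sketch of such a schedule (linear_keep_prob is a hypothetical name), plus a helper that counts how many updates the multiplicative rule used by the DropBlock class later in this post, keep_prob = max(target, keep_prob · (1 − decay)), needs to reach its target. Note that the class applies that rule on every forward call in eval mode, i.e. once per validation batch rather than once per epoch.

```python
def linear_keep_prob(step: int, total_steps: int,
                     start: float = 1.0, target: float = 0.9) -> float:
    """Linearly anneal keep_prob from `start` to `target` over `total_steps`, then hold it."""
    if total_steps <= 0:
        return target
    progress = min(max(step / total_steps, 0.0), 1.0)
    return start + (target - start) * progress


def updates_to_target(start: float = 1.0, target: float = 0.9, decay: float = 1e-2) -> int:
    """Count updates of keep_prob = max(target, keep_prob * (1 - decay)) until it reaches target."""
    if decay <= 0:
        raise ValueError("decay must be positive, otherwise keep_prob never reaches target")
    keep_prob, updates = start, 0
    while keep_prob > target:
        keep_prob = max(target, keep_prob * (1 - decay))
        updates += 1
    return updates


print(linear_keep_prob(50, 100))  # about 0.95
print(updates_to_target())        # 11
```

With the class defaults (decay 1e-2), keep_prob therefore reaches 0.9 after 11 eval-mode forward calls.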
The experimental results show that adding DropBlock to ResNet-50 improves its validation accuracy to some extent.

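The following compares validation accuracy for different DropBlock placements, application schemes and block_size values:
- Rows: DropBlock added after conv group 4 of ResNet-50; DropBlock added after conv groups 3 and 4 of ResNet-50
- Columns: applied only to the convolution branch; applied to both the convolution branch and the residual (skip) connection branch; applied to both branches with the keep_prob decay schedule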

In the paper, the best hyperparameters are block_size = 7 and keep_prob = 0.9, but in practice they still need adjusting according to how the loss behaves.
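Reproducing DropBlock
The implementation relies on the following details:
- keep_prob changes over time; here it is updated every time the module runs in eval mode
- The centers of the erased regions are chosen among the active units (outputs greater than 0), with 1 marking a selected center; max_pool2d then grows each center into a contiguous region, which produces the erase mask
- The normalization factor is total area / kept area, but computing the kept area exactly costs extra compute and slows training, so 1 / keep_prob is used as an approximation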
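To illustrate the last two points, the sketch below (block_mask_from_centers is a hypothetical helper name) places a single center on a 9 × 9 map, grows it with max_pool2d, and compares the exact normalization factor with the 1 / keep_prob approximation.

```python
import torch
import torch.nn.functional as F


def block_mask_from_centers(center_mask: torch.Tensor, block_size: int) -> torch.Tensor:
    """Grow every 1 in a [bs, ch, h, w] center mask into a block_size x block_size square."""
    assert block_size % 2 == 1, "an odd block_size keeps the spatial size unchanged"
    # stride 1 and padding block_size // 2 => output has the same h, w as the input
    return F.max_pool2d(center_mask, kernel_size=block_size,
                        stride=1, padding=block_size // 2)


center_mask = torch.zeros(1, 1, 9, 9)
center_mask[0, 0, 4, 4] = 1.0                      # one block center in the middle
block_mask = block_mask_from_centers(center_mask, block_size=3)
keep_mask = 1 - block_mask

exact_gain = keep_mask.numel() / keep_mask.sum()   # total area / kept area
approx_gain = 1 / 0.9                               # 1 / keep_prob, as in the class
print(block_mask[0, 0])                             # 3 x 3 square of ones at rows/cols 3..5
print(exact_gain.item(), approx_gain)               # 1.125 vs about 1.111
```

Here the exact factor is 81 / 72 = 1.125 against 1 / 0.9 ≈ 1.111: the approximation skips the per-step count at the cost of only roughly matching the true kept area.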
```python
import torch
import torch.nn as nn


class DropBlock(nn.Module):
    ''' block_size: size of the erased region
        keep_prob_init: initial value of keep_prob
        keep_prob_tar: target value of keep_prob
        keep_prob_decay: decay rate of keep_prob'''

    def __init__(self, block_size=5, keep_prob_init=1.,
                 keep_prob_tar=0.9, keep_prob_decay=1e-2):
        super(DropBlock, self).__init__()
        self.block_size = block_size
        assert self.block_size & 1, 'block_size must be odd'
        # keep_prob-related parameters
        self.keep_prob = keep_prob_init
        self._keep_prob_tar = keep_prob_tar
        self._keep_prob_decay = keep_prob_decay
        # mean of the Bernoulli distribution (computed lazily)
        self.gamma = None

    def forward(self, x):
        # training mode
        if self.training:
            *bs_ch, height, width = x.shape
            square = height * width
            # compute γ from the current keep_prob when it is not set
            if self.gamma is None:
                self.gamma = (1 - self.keep_prob) * square / self.block_size ** 2
                for f_size in (height, width):
                    self.gamma /= f_size - self.block_size + 1
            # choose the centers of the erased regions among the active units
            # (cast to x.dtype so half-precision inputs stay half precision)
            del_mask = torch.bernoulli((x > 0).to(x.dtype) * self.gamma)
            keep_mask = 1 - torch.max_pool2d(
                del_mask, kernel_size=self.block_size,
                stride=1, padding=self.block_size // 2
            )
            # rescale the feature map; 1 / keep_prob approximates the exact gain below
            # gain = square / keep_mask.view(*bs_ch, -1).sum(2).view(*bs_ch, 1, 1)
            return keep_mask * x / self.keep_prob
        # eval mode: decay keep_prob towards its target and reset γ
        self.keep_prob = max(self._keep_prob_tar,
                             self.keep_prob * (1 - self._keep_prob_decay))
        self.gamma = None
        return x
```

Code test
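```python
import cv2 as cv
import torchvision.transforms as tf

# use the grayscale image to set dim pixels to 0
image = cv.imread('YouXiZi.jpg')
mask = cv.cvtColor(image, cv.COLOR_BGR2GRAY) > 100
for i in range(3):
    image[..., i] *= mask
cv.imshow('debug', image)
cv.waitKey(0)
# convert to a tensor and apply DropBlock (a new module starts in training mode)
tensor = tf.ToTensor()(image)
db = DropBlock(block_size=31, keep_prob_init=0.9)
image = db(tensor.unsqueeze(0))[0]
image = image.permute(1, 2, 0).detach().numpy()
cv.imshow('debug', image)
cv.waitKey(0)
```

The grayscale image is used to zero out the dark pixels, so the bright regions are the active units.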

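All erase centers lie inside the bright regions, and the image is somewhat brighter than the original because the normalization factor is greater than 1.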
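Because the test above needs a local image file and a display, here is a file-free variant of the same check, under my own assumptions: a synthetic 64 × 64 map whose left half is bright (0.8) and right half is dark stands in for YouXiZi.jpg, and drop_block_active is a hypothetical function that re-implements the training-mode forward pass of the DropBlock class so the block runs on its own. It checks both observations: no center lands in the dark half, and surviving pixels are scaled by 1 / keep_prob.

```python
import torch
import torch.nn.functional as F


def drop_block_active(x, keep_prob=0.9, block_size=7):
    """Training-mode forward of the DropBlock class, also returning the center mask (hypothetical)."""
    height, width = x.shape[-2:]
    gamma = (1 - keep_prob) / block_size ** 2 * height * width
    gamma /= (height - block_size + 1) * (width - block_size + 1)
    # centers are only sampled where the activation is positive
    center_mask = torch.bernoulli((x > 0).to(x.dtype) * gamma)
    keep_mask = 1 - F.max_pool2d(center_mask, kernel_size=block_size,
                                 stride=1, padding=block_size // 2)
    return keep_mask * x / keep_prob, center_mask


torch.manual_seed(0)
# synthetic stand-in for the masked photo: bright left half, dark right half
image = torch.zeros(1, 3, 64, 64)
image[..., :32] = 0.8
out, centers = drop_block_active(image, keep_prob=0.9, block_size=7)
print(centers[..., 32:].sum().item())  # 0.0: no centers in the dark half
```

Blocks whose centers sit near the boundary can still spill into the dark half, but since those pixels are already 0 the output there stays 0.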
