当前位置:网站首页>PyTorch 卷积网络正则化 DropBlock
PyTorch 卷积网络正则化 DropBlock
2022-07-03 01:51:00 【荷碧·TongZJ】
论文地址:https://arxiv.org/pdf/1810.12890.pdf
论文概要
DropBlock 是一种类似于 dropout 的简单方法,它与 dropout 的主要区别在于,它从层的特征图中抹除连续区域,而不是抹除独立的随机单元
类似地,DropBlock 通过随机地置零网络的响应,实现了通道之间的解耦,缓解了网络的过拟合现象
这个算法的伪代码如下:
- x:特征图,shape 为 [bs, ch, h, w]
- block_size:抹除连续区域的尺寸
- γ:伯努利分布的均值,用于选中抹除区域的中心点
- trainning:布尔类型,表明是 train 模式还是 eval 模式
def DropBlock(x, block_size, γ, trainning):
if trainning:
# 选中要抹除区域的中心点
del_mask = bernoulli(x, γ)
# 抹除相应的区域
x = set_zero(x, del_mask, block_size)
# 特征图标准化
keep_mask = 1 - del_mask
x *= count(x) / count_1(keep_mask)
return x
# eval 模式下没有任何行为
return x但是在具体实现的过程中,还有很多需要补充的细节

γ 的确定是通过 keep_prob 参数确定的,keep_prob 表示激活单元 (即输出大于 0) 被保留的概率,feat_size 为特征图的尺寸:

因为在训练刚开始时,较小的 keep_prob 会影响网络的收敛,所以令 keep_prob 从 1.0 渐渐降为 0.9
从实验结果可以看到,ResNet-50 在使用了 DropBlock 后在验证集上的准确率有一定的提升

以下是不同的 DropBlock 追加位置、不同的处理方法、不同 block_size 对验证集准确率的影响:
- 按行:DropBlock 追加在 ResNet-50 的第 4 组卷积后;DropBlock 追加在 ResNet-50 的第3、第 4 组卷积后
- 按列:只在卷积分支上追加;在卷积分支、残差连接分支上追加;在卷积分支、残差连接分支上追加,并使用 keep_prob 衰减的方法

在论文中,最优的超参数是 block_size = 7, keep_prob = 0.9,但实际使用时仍需要根据 Loss 的变化情况做出调整
DropBlock 复现
在实现 DropBlock 时,有以下几个细节:
- keep_prob 是动态变化的,令每次 eval 时进行更新
- 抹除区域的中心点是在激活单元中选择的 (即输出大于 0),令 1 表示被选中,使用 max_pool2d 可以实现连续区域的选中,以生成 del_mask
- 标准化系数 = 原图面积 / 保留区域面积,但是保留区域面积的精确值计算会耗费较多的算力,减缓网络训练的速度,所以标准化系数使用 1 / keep_prob 近似替代
class DropBlock(nn.Module):
''' block_size: 抹除区域的尺寸
keep_prob_init: keep_prob 的初始值
keep_prob_tar: keep_prob 的目标值
keep_prob_decay: keep_prob 的衰减速度'''
def __init__(self, block_size=5, keep_prob_init=1.,
keep_prob_tar=0.9, keep_prob_decay=1e-2):
super(DropBlock, self).__init__()
self.block_size = block_size
assert self.block_size & 1, 'block_size 需为奇数'
# keep_prob 相关参数
self.keep_prob = keep_prob_init
self._keep_prob_tar = keep_prob_tar
self._keep_prob_decay = keep_prob_decay
# 伯努利分布的均值
self.gamma = None
def forward(self, x):
# 训练模式下
if self.training:
*bs_ch, height, width = x.shape
square = height * width
# 当 γ 为空时设置
if self.gamma is None:
self.gamma = (1 - self.keep_prob) * square / self.block_size ** 2
for f_size in (height, width):
self.gamma /= f_size - self.block_size + 1
# 在激活区域中, 选择抹除区域的中心点
del_mask = torch.bernoulli((x > 0) * self.gamma)
keep_mask = 1 - torch.max_pool2d(
del_mask, kernel_size=self.block_size,
stride=1, padding=self.block_size // 2
)
# 特征图标准化
# gain = square / keep_mask.view(*bs_ch, -1).sum(2).view(*bs_ch, 1, 1)
return keep_mask * x / self.keep_prob
# 验证模式下, 更新参数
self.keep_prob = max([
self._keep_prob_tar,
self.keep_prob * (1 - self._keep_prob_decay)
])
self.gamma = None
return x代码测试
# 利用灰度图, 将亮度低的像素置为 0
image = cv.imread('YouXiZi.jpg')
mask = cv.cvtColor(image, cv.COLOR_BGR2GRAY) > 100
for i in range(3):
image[..., i] *= mask
cv.imshow('debug', image)
cv.waitKey(0)
# 转化为 tensor, 使用 DropBlock
tensor = tf.ToTensor()(image)
db = DropBlock(block_size=31, keep_prob_init=0.9)
image = db(tensor.unsqueeze(0))[0]
image = image.permute(1, 2, 0).data.numpy()
cv.imshow('debug', image)
cv.waitKey(0)利用灰度图将亮度暗的像素置零,亮区即为激活单元

抹除区域的中心点均出现在亮区内,而且图像的亮度相较于原图有一定提升 (标准化系数 > 1)

边栏推荐
- y54.第三章 Kubernetes从入门到精通 -- ingress(二七)
- [camera special topic] Hal layer - brief analysis of addchannel and startchannel
- What are the differences between software testers with a monthly salary of 7K and 25K? Leaders look up to you when they master it
- Processing of tree structure data
- Network security ACL access control list
- 网络安全-动态路由协议RIP
- ByteDance data Lake integration practice based on Hudi
- How can retail enterprises open the second growth curve under the full link digital transformation
- 查询商品案例-页面渲染数据
- es6 filter() 数组过滤方法总结
猜你喜欢

Network security - vulnerabilities and Trojans

His experience in choosing a startup company or a big Internet company may give you some inspiration

Stm32f407 ------- IIC communication protocol

微信小程序開發工具 POST net::ERR_PROXY_CONNECTION_FAILED 代理問題

Ni visa fails after LabVIEW installs the third-party visa software
![[camera topic] turn a drive to light up the camera](/img/d3/7aabaa5c75813abc4a43820b4c3706.png)
[camera topic] turn a drive to light up the camera

Certaines fonctionnalités du développement d'applets

詳細些介紹如何通過MQTT協議和華為雲物聯網進行通信

机器学习笔记(持续更新中。。。)

技术大佬准备就绪,话题C位由你决定
随机推荐
es6 filter() 数组过滤方法总结
[Yu Yue education] reference materials of chemical experiment safety knowledge of University of science and technology of China
Prohibited package name
Leetcode 183 Customers who never order (2022.07.02)
DDL basic operation
Method of removing webpage scroll bar and inner and outer margins
Machine learning notes (constantly updating...)
Network security - man in the middle attack
How to find summer technical internship in junior year? Are you looking for a large company or a small company for technical internship?
力扣(LeetCode)183. 从不订购的客户(2022.07.02)
stm32F407-------ADC
How to deal with cache hot key in redis
Depth (penetration) selector:: v-deep/deep/ and > > >
Custom components, using NPM packages, global data sharing, subcontracting
网络安全-NAT网络地址转换
可視化yolov5格式數據集(labelme json文件)
MySQL learning 03
Network security NAT network address translation
小程序开发黑马购物商城中遇到的问题
Introduction to kotlin collaboration