当前位置:网站首页>PyTorch 卷积网络正则化 DropBlock
PyTorch 卷积网络正则化 DropBlock
2022-07-03 01:51:00 【荷碧·TongZJ】
论文地址:https://arxiv.org/pdf/1810.12890.pdf
论文概要
DropBlock 是一种类似于 dropout 的简单方法,它与 dropout 的主要区别在于,它从层的特征图中抹除连续区域,而不是抹除独立的随机单元
类似地,DropBlock 通过随机地置零网络的响应,实现了通道之间的解耦,缓解了网络的过拟合现象
这个算法的伪代码如下:
- x:特征图,shape 为 [bs, ch, h, w]
- block_size:抹除连续区域的尺寸
- γ:伯努利分布的均值,用于选中抹除区域的中心点
- trainning:布尔类型,表明是 train 模式还是 eval 模式
def DropBlock(x, block_size, γ, trainning):
if trainning:
# 选中要抹除区域的中心点
del_mask = bernoulli(x, γ)
# 抹除相应的区域
x = set_zero(x, del_mask, block_size)
# 特征图标准化
keep_mask = 1 - del_mask
x *= count(x) / count_1(keep_mask)
return x
# eval 模式下没有任何行为
return x
但是在具体实现的过程中,还有很多需要补充的细节
γ 的确定是通过 keep_prob 参数确定的,keep_prob 表示激活单元 (即输出大于 0) 被保留的概率,feat_size 为特征图的尺寸:
因为在训练刚开始时,较小的 keep_prob 会影响网络的收敛,所以令 keep_prob 从 1.0 渐渐降为 0.9
从实验结果可以看到,ResNet-50 在使用了 DropBlock 后在验证集上的准确率有一定的提升
以下是不同的 DropBlock 追加位置、不同的处理方法、不同 block_size 对验证集准确率的影响:
- 按行:DropBlock 追加在 ResNet-50 的第 4 组卷积后;DropBlock 追加在 ResNet-50 的第3、第 4 组卷积后
- 按列:只在卷积分支上追加;在卷积分支、残差连接分支上追加;在卷积分支、残差连接分支上追加,并使用 keep_prob 衰减的方法
在论文中,最优的超参数是 block_size = 7, keep_prob = 0.9,但实际使用时仍需要根据 Loss 的变化情况做出调整
DropBlock 复现
在实现 DropBlock 时,有以下几个细节:
- keep_prob 是动态变化的,令每次 eval 时进行更新
- 抹除区域的中心点是在激活单元中选择的 (即输出大于 0),令 1 表示被选中,使用 max_pool2d 可以实现连续区域的选中,以生成 del_mask
- 标准化系数 = 原图面积 / 保留区域面积,但是保留区域面积的精确值计算会耗费较多的算力,减缓网络训练的速度,所以标准化系数使用 1 / keep_prob 近似替代
class DropBlock(nn.Module):
''' block_size: 抹除区域的尺寸
keep_prob_init: keep_prob 的初始值
keep_prob_tar: keep_prob 的目标值
keep_prob_decay: keep_prob 的衰减速度'''
def __init__(self, block_size=5, keep_prob_init=1.,
keep_prob_tar=0.9, keep_prob_decay=1e-2):
super(DropBlock, self).__init__()
self.block_size = block_size
assert self.block_size & 1, 'block_size 需为奇数'
# keep_prob 相关参数
self.keep_prob = keep_prob_init
self._keep_prob_tar = keep_prob_tar
self._keep_prob_decay = keep_prob_decay
# 伯努利分布的均值
self.gamma = None
def forward(self, x):
# 训练模式下
if self.training:
*bs_ch, height, width = x.shape
square = height * width
# 当 γ 为空时设置
if self.gamma is None:
self.gamma = (1 - self.keep_prob) * square / self.block_size ** 2
for f_size in (height, width):
self.gamma /= f_size - self.block_size + 1
# 在激活区域中, 选择抹除区域的中心点
del_mask = torch.bernoulli((x > 0) * self.gamma)
keep_mask = 1 - torch.max_pool2d(
del_mask, kernel_size=self.block_size,
stride=1, padding=self.block_size // 2
)
# 特征图标准化
# gain = square / keep_mask.view(*bs_ch, -1).sum(2).view(*bs_ch, 1, 1)
return keep_mask * x / self.keep_prob
# 验证模式下, 更新参数
self.keep_prob = max([
self._keep_prob_tar,
self.keep_prob * (1 - self._keep_prob_decay)
])
self.gamma = None
return x
代码测试
# 利用灰度图, 将亮度低的像素置为 0
image = cv.imread('YouXiZi.jpg')
mask = cv.cvtColor(image, cv.COLOR_BGR2GRAY) > 100
for i in range(3):
image[..., i] *= mask
cv.imshow('debug', image)
cv.waitKey(0)
# 转化为 tensor, 使用 DropBlock
tensor = tf.ToTensor()(image)
db = DropBlock(block_size=31, keep_prob_init=0.9)
image = db(tensor.unsqueeze(0))[0]
image = image.permute(1, 2, 0).data.numpy()
cv.imshow('debug', image)
cv.waitKey(0)
利用灰度图将亮度暗的像素置零,亮区即为激活单元
抹除区域的中心点均出现在亮区内,而且图像的亮度相较于原图有一定提升 (标准化系数 > 1)
边栏推荐
- Return the only different value (de duplication)
- 网络安全-NAT网络地址转换
- String replace space
- How to deal with cache hot key in redis
- [shutter] hero animation (hero realizes radial animation | hero component createrecttween setting)
- Network security - vulnerabilities and Trojans
- 2022 spring "golden three silver four" job hopping prerequisites: Software Test interview questions (with answers)
- [Yu Yue education] reference materials of chemical experiment safety knowledge of University of science and technology of China
- 网络安全-防火墙
- stm32F407-------DMA
猜你喜欢
Certaines fonctionnalités du développement d'applets
easyExcel
Redis:Redis的简单使用
Hard core observation 547 large neural network may be beginning to become aware?
微信小程序开发工具 POST net::ERR_PROXY_CONNECTION_FAILED 代理问题
【Camera专题】手把手撸一份驱动 到 点亮Camera
A 30-year-old software tester, who has been unemployed for 4 months, is confused and doesn't know what to do?
《上市风云》荐书——唯勇气最可贵
[shutter] bottom navigation bar implementation (bottomnavigationbar bottom navigation bar | bottomnavigationbaritem navigation bar entry | pageview)
Query product cases - page rendering data
随机推荐
Analyzing several common string library functions in C language
Problems encountered in small program development of dark horse shopping mall
In the face of difficult SQL requirements, HQL is not afraid
Types of map key and object key
What are MySQL locks and classifications
Redis:Redis的简单使用
小程序开发的部分功能
DQL basic operation
[camera special topic] Hal layer - brief analysis of addchannel and startchannel
Deep learning notes (constantly updating...)
stm32F407-------ADC
Rockchip3399 start auto load driver
Niuniu's ball guessing game (dynamic planning + prefix influence)
How to refresh the opening amount of Oracle ERP
[camera topic] complete analysis of camera dtsi
Reprint some Qt development experience written by great Xia 6.5
Visualisation de l'ensemble de données au format yolov5 (fichier labelme json)
Network security - cracking system passwords
Answers to ten questions about automated testing software testers must see
What are the key points often asked in the redis interview