当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-07-31 17:31:00 【GIS与Climate】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
边栏推荐
猜你喜欢
Go1.18升级功能 - 模糊测试Fuzz 从零开始Go语言
35 MySQL interview questions and diagrams, this is also easy to understand
Chinese encoding Settings and action methods return values
21.支持向量机—核函数的介绍
动态规划之线性dp(上)
这位985教授火了!当了10年博导,竟无一博士毕业!
新型电信“套路”,我爸中招了!
MySQL - multi-table query
九齐ny3p系列语音芯片替代国产方案KT148A性价比更高420秒长度
Go basic part study notes
随机推荐
IP协议从0到1
牛客 HJ3 明明的随机数
Flutter 获取状态栏statusbar的高度
最新神作!阿里巴巴刚出炉的面试参考指南(泰山版),我直接狂刷29天
阿里三面:MQ 消息丢失、重复、积压问题,如何解决?
Flex布局详解
Golang 小数操作之判断几位小数点与四舍五入
京东按关键字搜索商品 API
多主复制下处理写冲突(3)-收敛至一致的状态及自定义冲突解决逻辑
The server encountered an internal error that prevented it from fulfilling this request的一种解决办法[通俗易懂]
【pytorch】pytorch 自动求导、 Tensor 与 Autograd
最后写入胜利(丢弃并发写入)
【码蹄集新手村600题】不通过字符数组来合并俩个数字
牛客网刷题(二)
LevelSequence源码分析
go mode tidy出现报错go warning “all“ matched no packages
Concurrency, Timing and Relativity
35道MySQL面试必问题图解,这样也太好理解了吧
MySQL---aggregate function
并发性,时间和相对性