当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-07-31 17:31:00 【GIS与Climate】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
边栏推荐
- 新型电信“套路”,我爸中招了!
- 九齐ny3p系列语音芯片替代国产方案KT148A性价比更高420秒长度
- The article you worked so hard to write may not be your original
- Taobao/Tmall get Taobao password real url API
- 杰理语音芯片ic玩具芯片ic的介绍_AD14NAD15N全系列开发
- MySQL---operator
- Mariabackup实现Mariadb 10.3的增量数据备份
- INeuOS industrial Internet operating system, the equipment operational business and "low code" form development tools
- 【NLP】什么是模型的记忆力!
- 动态规划之线性dp(上)
猜你喜欢
All-platform GPU general AI video supplementary frame super-score tutorial
20.支持向量机—数学原理知识
这位985教授火了!当了10年博导,竟无一博士毕业!
Huawei mobile phone one-click to open "maintenance mode" to hide all data and make mobile phone privacy more secure
Go basic part study notes
go mode tidy出现报错go warning “all“ matched no packages
上传图片-微信小程序(那些年的坑记录2022.4)
Jiuqi ny3p series voice chip replaces the domestic solution KT148A, which is more cost-effective and has a length of 420 seconds
全平台GPU通用AI视频补帧超分教程
A common method and the use of selenium
随机推荐
动态规划(一)
Handling Write Conflicts under Multi-Master Replication (1)-Synchronous and Asynchronous Conflict Detection and Conflict Avoidance
Last write wins (discards concurrent writes)
mysql的备份表的几种方法
After Effects tutorial, How to adjust overexposed snapshots in After Effects?
你辛辛苦苦写的文章可能不是你的原创
【luogu P8326】Fliper (Graph Theory) (Construction) (Eulerian Circuit)
MySQL - single function
Flutter set the background color of the statusbar status bar and APP method (AppBar) internal consistent color.
Automated testing - web automation - first acquaintance with selenium
UserAgent 解析
ThreadLocal
Huawei mobile phone one-click to open "maintenance mode" to hide all data and make mobile phone privacy more secure
[Source code analysis] BeanFactory and FactoryBean
【luogu P8326】Fliper(图论)(构造)(欧拉回路)
MATLAB程序设计与应用 2.4 MATLAB常用内部函数
API for JD.com to obtain historical price information of commodities
go基础部分学习笔记记录
Bika LIMS 开源LIMS集—— SENAITE的使用(检测流程)
adb shell 报错error: device unauthorized