当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-07-31 17:31:00 【GIS与Climate】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
边栏推荐
- 牛客网刷题(一)
- 【Yugong Series】July 2022 Go Teaching Course 020-Array of Go Containers
- MySQL - single function
- 组合学笔记(六)局部有限偏序集的关联代数,Möbius反演公式
- Go record - slice
- Three aspects of Ali: How to solve the problem of MQ message loss, duplication and backlog?
- 每日练习------随机产生一个1-100之间的整数,看能几次猜中。要求:猜的次数不能超过7次,每次猜完之后都要提示“大了”或者“小了”。
- MySQL---基本的select语句
- Intelligent bin (9) - vibration sensor (raspberries pie pico implementation)
- 九齐ny3p系列语音芯片替代国产方案KT148A性价比更高420秒长度
猜你喜欢

After Effects 教程,如何在 After Effects 中调整过度曝光的快照?

【pytorch】1.7 pytorch与numpy,tensor与array的转换

使用互相关进行音频对齐

IP协议从0到1

flyway的快速入门教程

阿里三面:MQ 消息丢失、重复、积压问题,如何解决?
![[Network Communication 3] Advantech Gateway Modbus Service Settings](/img/ec/e9e1d9a374183ecaa8a8c9437ec82c.png)
[Network Communication 3] Advantech Gateway Modbus Service Settings

学生管理系统第一天:完成登录退出操作逻辑 PyQt5 + MySQL5.8

Automated testing - web automation - first acquaintance with selenium

基于WPF重复造轮子,写一款数据库文档管理工具(一)
随机推荐
Bika LIMS 开源LIMS集—— SENAITE的使用(检测流程)
BOW/DOM(上)
Verilog实现占空比为5/18的9分频
如何识别假爬虫?
MySQL common statements
京东按关键字搜索商品 API
adb shell error error: device unauthorized
Write a database document management tool based on WPF repeating the wheel (1)
The article you worked so hard to write may not be your original
常用的安全渗透测试工具(渗透测试工具)
你辛辛苦苦写的文章可能不是你的原创
Handling Write Conflicts under Multi-Master Replication (1)-Synchronous and Asynchronous Conflict Detection and Conflict Avoidance
中文编码的设置与action方法的返回值
This 985 professor is on fire!After 10 years of Ph.D. supervisor, no one has graduated with a Ph.D.!
useragent怎么获取
go记录之——slice
MySQL - single function
Automated testing - web automation - first acquaintance with selenium
【码蹄集新手村600题】通向公式与程序相结合
After Effects tutorial, How to adjust overexposed snapshots in After Effects?