当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-07-31 17:31:00 【GIS与Climate】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
边栏推荐
- 【网络通信三】研华网关Modbus服务设置
- MySQL---聚合函数
- Bika LIMS open source LIMS set - use of SENAITE (detection process)
- [TypeScript]OOP
- The article you worked so hard to write may not be your original
- Golang 切片删除指定元素的几种方法
- 【Yugong Series】July 2022 Go Teaching Course 021-Slicing Operation of Go Containers
- flutter设置statusbar状态栏的背景颜色和 APP(AppBar)内部颜色一致方法。
- Kotlin协程:续体、续体拦截器、调度器
- 基于WPF重复造轮子,写一款数据库文档管理工具(一)
猜你喜欢

动态规划(一)

使用互相关进行音频对齐

Bika LIMS 开源LIMS集—— SENAITE的使用(检测流程)

杰理语音芯片ic玩具芯片ic的介绍_AD14NAD15N全系列开发

Jiuqi ny3p series voice chip replaces the domestic solution KT148A, which is more cost-effective and has a length of 420 seconds
![[pytorch] pytorch automatic derivation, Tensor and Autograd](/img/99/c9632a7d3f70a13e1e26b9aa67b8b9.png)
[pytorch] pytorch automatic derivation, Tensor and Autograd
![[TypeScript] OOP](/img/d7/b3175ab538906ac1b658a9f361ba44.png)
[TypeScript] OOP

智能垃圾桶(九)——震动传感器(树莓派pico实现)

Intelligent bin (9) - vibration sensor (raspberries pie pico implementation)

新型电信“套路”,我爸中招了!
随机推荐
20.支持向量机—数学原理知识
21.支持向量机—核函数的介绍
[Network Communication 3] Advantech Gateway Modbus Service Settings
多数据中心操作和检测并发写入
Golang 切片删除指定元素的几种方法
如何识别假爬虫?
【愚公系列】2022年07月 Go教学课程 022-Go容器之字典
Golang 必知必会Go Mod命令
你辛辛苦苦写的文章可能不是你的原创
go基础部分学习笔记记录
useragent在线查找
Go basic part study notes
研发过程中的文档管理与工具
How to install CV2 smoothly in Anaconda
几款永久免费内网穿透,好用且简单(内网穿透教程)
flyway的快速入门教程
Intelligent bin (9) - vibration sensor (raspberries pie pico implementation)
动态规划之线性dp(上)
基于WPF重复造轮子,写一款数据库文档管理工具(一)
全平台GPU通用AI视频补帧超分教程