当前位置:网站首页>ResNet的基础:残差块的原理
ResNet的基础:残差块的原理
2022-07-31 17:31:00 【GIS与Climate】
在深度学习中,为了增强模型的学习能力,网络层会变得越来越深,但是随着深度的增加,也带来了比较一些问题,主要包括:
模型复杂度上升,网络训练困难; 梯度消失/梯度爆炸 网络退化,也就是说模型的学习能力达到了饱和,增加网络层数并不能提升精度了。
为了解决网络退化问题,何凯明大佬提出了深度残差网络,可以说是深度学习中一个非常大的创造性工作。
残差网络
残差网络的思想就是将网络学习的映射从X到Y转为学习从X到Y-X的差,然后把学习到的残差信息加到原来的输出上即可。即便在某些极端情况下,这个残差为0,那么网络就是一个X到Y的恒等映射。其示意图如下:

上图的结构中其实主线与正常的网络结构没什么区别,差异在于右边的连接线,作者称之为Shortcut Connection,意思就是跳过了一些网络层直接与后面的某一个层的输出结果进行连接。
优势
残差网络中,因为残差块保留了原始输入的信息,所以网络有如下优势:
随着深度的增加,可以获取更高的精度,因为其学习的残差越准确; 网络优化比较简单; 比较通用;
残差块的实现
按照上面的图所示的结构,在Pytorch中实现一个残差块也非常简单,无非就是在传统的网络中加上一个shortcut connection,比如一个最基础的残差块代码如下:
class ResidualBlock(nn.Module):
def __init__(self, channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(channels)
self.prelu = nn.PReLU()
self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.prelu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.prelu(x + residual)
return out
通过上面的代码就实现了一个最最基础的残差块(仅仅是按图实现的,跟原文里面的不太一样)。需要注意的地方有:
残差块因为在forward函数的最后需要将输入x和学习到的残差(也就是 )相加,所以这两个张量的尺寸应该是完全一致的; 在最后将 相加之后再输入激活函数; 每一个卷积层后要跟上一个批归一化层。
在真正用的时候,上面的代码还需要再进行复杂化,比如是否需要对数据进行下采样等,不过看懂了上面的基础,就可以自己进行相应的修改,来适用于自己的网络。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
边栏推荐
- 宁波大学NBU IT项目管理期末考试知识点整理
- Golang go-redis cluster模式下不断创建新连接,效率下降问题解决
- MySQL---Create and manage databases and data tables
- 【愚公系列】2022年07月 Go教学课程 020-Go容器之数组
- 浅谈网络安全之算法安全
- MySQL---operator
- 牛客 HJ3 明明的随机数
- Kotlin coroutines: continuation, continuation interceptor, scheduler
- Flutter 获取状态栏statusbar的高度
- AcWing 1282. 搜索关键词 题解((AC自动机)Trie+KMP)+bfs)
猜你喜欢
随机推荐
联邦学习:联邦场景下的多源知识图谱嵌入
九齐ny3p系列语音芯片替代国产方案KT148A性价比更高420秒长度
IP protocol from 0 to 1
35道MySQL面试必问题图解,这样也太好理解了吧
Smart Trash Can (8) - Infrared Tube Sensor (Raspberry Pi pico)
After Effects tutorial, How to adjust overexposed snapshots in After Effects?
全平台GPU通用AI视频补帧超分教程
TestCafe之如何进行调试
IP协议从0到1
Bika LIMS 开源LIMS集—— SENAITE的使用(检测流程)
flowable工作流所有业务概念
Go1.18升级功能 - 模糊测试Fuzz 从零开始Go语言
[TypeScript] OOP
10 Ways to Keep Your Interface Data Safe
Concurrency, Timing and Relativity
useragent怎么获取
Huawei mobile phone one-click to open "maintenance mode" to hide all data and make mobile phone privacy more secure
宁波大学NBU IT项目管理期末考试知识点整理
Go record - slice
你辛辛苦苦写的文章可能不是你的原创









