当前位置:网站首页>ResNet论文解读及代码实现(pytorch)
ResNet论文解读及代码实现(pytorch)
2022-07-26 22:37:00 【Ap21ril】
又重新看了一遍何凯明大神的残差网络,之前懵懵懂懂的知识豁然开朗了起来。然后,虽然现在CSDN和知乎的风气不是太好,都是一些复制粘贴别人的作品来给自己的博客提高阅读量的人,但是也可以从其中汲取到很多有用的知识,我们要取其精华,弃其糟粕。
我只是大概的记录一下ResNet论文讲了什么,希望大家还是可以自己去读几遍。
ResNet论文链接为:https://arxiv.org/abs/1512.03385
1.前言
在读这篇文章之前,希望可以思考一个问题。残差网络到底是用来干什么的?我想很多人思考过后后的回答就是“残差网络不就是解决过深的网络引起的梯度消失和梯度爆炸这种现象嘛。”
这个回答是没问题的,但是梯度消失和梯度爆炸可以通过归一化初始化或中间层归一化来解决,还有最重要的一个原因就是过深的网络会出现网络退化的问题。何为网络退化?这里拿出原论文中的一个插图来解释。
按照正常的逻辑来说,神经网络越深训练的效果应该会越好啊,但实验推翻了我们这个结论。
假设一个比较浅的网络已经可以达到不错的效果,那么即使之后堆上去的网络什么也不做,模型的效果也不会变差。
问题就出现在这里,什么都不做就是神经网络最难做到的东西。
2. 论文解读
1. 摘要
在摘要中,作者主要是提出了一种残差结构,这种结构不需要额外的参数也更容易被优化,结果证明加了这种结构后不管是在ImageNet分类数据集上还是在COCO目标检测数据集上都有很好的效果,啊,不对,不是很好,是在大赛中都取得了第一名。
2. 引言
第一段主要介绍了一些研究背景,深度卷积网络为图像分类带来了一系列的突破。网络深度对训练模型来说也是至关重要的。
作者在第二段就表明了,随着深度的增加会出现梯度消失、梯度爆炸,但是这个问题已经在很大程度上通过归一化初始化和中间归一化层得到了解决。
第三段提出:随着网络深度的增加,精度达到饱和(这可能并不奇怪),然后迅速退化。然而,这种退化并不是由过拟合引起的,在适当深度的模型上添加更多的层会导致更高的训练误差。
第四段提出了一个名词identity mapping,我理解为输入为X,输出也为X。构造深的模型,虽然有解决方案,但是在可行时间内是不可能实现的。
第五段提出了残差网络的结构
X 为浅层网络的输出。如果我们想要得到的映射为H(X),则我们让添加的非线性网络层去拟合残差映射F(X):=H(X)-X。原始的映射就可以写成F(X)+X。图二中右侧连接为一个identity mapping。
第六段提出了shortcut connection的概念。这种连接可以跳过一层或更多层,因为F(X)和X是直接相加的。所以不需要额外的参数和更复杂的计算。
剩余部分都是在说这个模型有多好,在100层,甚至1000层的模型上效果都很好。
我们假设第一个网络在训练集和测试集上可以得到很好的性能(甚至可以理解为接近100%)。
那么在这个新的网络,由于我们copy了前四层的参数,理论上前四层已经足够满足我们的性能要求,那么新增加的层便显得有些多余,如果这个新的网络也要达到性能100%,则新增加的层要做的事情就是“恒等映射”,也即后面几个紫色的层要实现的效果为 
。这样一来,网络的性能一样能达到100%。而退化现象也表明了,实际上新增加的几个紫色的层,很难做到恒等映射。又或者能做到,但在有限的时间内很难完成(即网络要用指数级别的时间才能达到收敛)。这时候,巧妙的通过添加”桥梁“,使得难以优化的问题瞬间迎刃而解。
可以看到通过添加这个桥梁,把数据原封不动得送到FC层的前面,而对于中间的紫色层,可以很容易的通过把这些层的参数逼近于0,进而实现
的功能。
实际上,网络性能通常未能达到100%,可以假设最初的网络(只有前四层)的性能到了98%等等,如果不添加跳连接,增加三个紫色层之后的新网络同样难以进行优化(由上面极端情况的推广,也即前面四层的性能达到100%)。
而通过跳连接,可以把前四层的输出先送到FC层前面,也就相当于告诉紫色层:”兄弟你放心,我已经做完98%的工作了,你看看能不能在剩下的2%中发点力,你要是找不出提升性能的效果也没事的,我们可以把你的参数逼近于0,所以放心大胆的找吧。"
我们把整个映射看成100%,则前面四层网络实现了98%的映射关系,而残余的映射由紫色层完成,Residual 另一个翻译就是"残余,残留“的意思,也就是让每一个残差块,只关注残余映射的一小部分,真的是恰到好处。
当然了,实际上网络运行的时候,我们并不会知道哪几层就能达到很好的效果,然后在它们的后面接一个跳连接,于是一开始便在两个层或者三个层之间添加跳连接,形成残差块,每个残差块只关注当前的残余映射,而不会关注前面已经实现的底层映射。
不得不佩服他对神经网络的深入理解,从他灵感的来源,让我感觉他就是个数学大佬,结果一查还真是,本科是清华基础科学班的(研究物理数学的)(拿烟的手微微颤抖)。好了,废话不多说,让我们一起来理解什么是残差网络。我们先来看一个现象,假设我们有如下的一个网络,它可以在训练集和测试
这个例子来自于知乎大佬Sakura当时我看到后真的是茅塞顿开。
作者:Sakura
链接:https://www.zhihu.com/question/306135761/answer/2491142607
来源:知乎
3. 相关工作
相关工作是说在这篇论文之前已经有一些人在某些方向上使用本论文中使用的一些方法去做事。这句话很拗口,但是只要你理解了,就会发现所有论文的相关方向都是说的这东西。
4. 深度残差网络
这部分就是介绍了残差模块的构成还有各种ResNet的网络结构。
在上文中,我们说道F(X)和X直接相加,因此需要保证他们的维度一定要一样,否则就对X做投影。
关于残差网络结构(以34层为例),只截取了一部分
图右中的实线表示残差连接,虚线表示升维。
架构图:
更深的ResNet网络:
50层,101层,152层都使用图右中的结构,使用1X1的卷积主要是为了降维和升维。右边结构的参数量明显比左边的少。
总结
网络亮点:
- 超深的网络结构(突破1000层)
- 提出residual模块
- 使用Batch Normalization加速训练(丢弃dropout)
沐神很形象的描述出了为什么加一个残差连接就可以收敛。

代码实现
只是model模块
import torch.nn as nn
import torch
''' 对应18层,34层的残差结构 '''
class BasicBlock(nn.Module):
expansion = 1 #判断每一个卷积块中,卷积核的个数会不会有变化
def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs): # downsample表示是否有升维操作
super(BasicBlock, self).__init__()
# output = (input - kernel_size + 2*padding)/stride + 1
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=3, stride=stride, padding=1, bias=False) # stride=1表示option A;stride=2表示optionB 使用BN不需要偏置bias
self.bn1 = nn.BatchNorm2d(out_channel)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channel)
self.downsample = downsample
def forward(self, x):
identity = x
if self.downsample is not None:
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out += identity
out = self.relu(out)
return out
''' 50层,101层,152层 '''
class Bottleneck(nn.Module):
""" 注意:原论文中,在虚线残差结构的主分支上,第一个1x1卷积层的步距是2,第二个3x3卷积层步距是1。 但在pytorch官方实现过程中是第一个1x1卷积层的步距是1,第二个3x3卷积层步距是2, 这么做的好处是能够在top1上提升大概0.5%的准确率。 可参考Resnet v1.5 https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch """
expansion = 4
def __init__(self, in_channel, out_channel, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
kernel_size=1, stride=1, bias=False) # squeeze channels
self.bn1 = nn.BatchNorm2d(out_channel)
# -----------------------------------------
self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
kernel_size=3, stride=stride, bias=False, padding=1)
self.bn2 = nn.BatchNorm2d(out_channel)
# -----------------------------------------
self.conv3 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel*self.expansion,
kernel_size=1, stride=1, bias=False) # unsqueeze channels
self.bn3 = nn.BatchNorm2d(out_channel*self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x):
identity = x
if self.downsample is not None:
identity = self.downsample(x)
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self,
block, # 残差结构
blocks_num,
num_classes=1000,
include_top=True,
groups=1,
width_per_group=64):
super(ResNet, self).__init__()
self.include_top = include_top
self.in_channel = 64
self.groups = groups
self.width_per_group = width_per_group
self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(self.in_channel)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, blocks_num[0]) # 对应结构图中conv2_x,下面同理
self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
if self.include_top:
self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # output size = (1, 1)
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
''' block: BasicBlock或Bottleneck channel: 残差结构中的卷积核个数 block_num:这一层有多少残差结构,例:34的第一层有三个,第二层有四个 '''
def _make_layer(self, block, channel, block_num, stride=1):
downsample = None
# 快捷连接虚线部分
if stride != 1 or self.in_channel != channel * block.expansion:
downsample = nn.Sequential(
nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(channel * block.expansion))
layers = []
# 搭建每一个conv的第一层
layers.append(block(self.in_channel,
channel,
downsample=downsample,
stride=stride,
groups=self.groups,
width_per_group=self.width_per_group))
self.in_channel = channel * block.expansion
for _ in range(1, block_num):
layers.append(block(self.in_channel,
channel,
groups=self.groups,
width_per_group=self.width_per_group))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
if self.include_top:
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def resnet34(num_classes=1000, include_top=True):
# https://download.pytorch.org/models/resnet34-333f7ec4.pth
return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
def resnet50(num_classes=1000, include_top=True):
# https://download.pytorch.org/models/resnet50-19c8e357.pth
return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)
def resnet101(num_classes=1000, include_top=True):
# https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)
def resnext50_32x4d(num_classes=1000, include_top=True):
# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
groups = 32
width_per_group = 4
return ResNet(Bottleneck, [3, 4, 6, 3],
num_classes=num_classes,
include_top=include_top,
groups=groups,
width_per_group=width_per_group)
def resnext101_32x8d(num_classes=1000, include_top=True):
# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
groups = 32
width_per_group = 8
return ResNet(Bottleneck, [3, 4, 23, 3],
num_classes=num_classes,
include_top=include_top,
groups=groups,
width_per_group=width_per_group)
代码来自于B站大佬:https://space.bilibili.com/18161609
完整的代码和数据集在我的GitHub上。https://github.com/Glory-Peng/CV
边栏推荐
- Method of setting QQ to blank ID
- 04-传统的Synchronized锁
- Azure Synapse Analytics 性能优化指南(3)——使用具体化视图优化性能(下)
- 第1章 拦截器入门及使用技巧
- The attorney general and the director of the national security service of Ukraine were dismissed
- Familiarize you with the "phone book" of cloud network: DNS
- Share a regular expression
- 力扣141题:环形链表
- NFT display guide: how to display your NFT collection
- 第6节:cmake语法介绍
猜你喜欢

Chapter 1 Introduction and use skills of interceptors

Complete backpack and 01 Backpack

Practice of data storage scheme in distributed system
![[netding Cup 2018] Fakebook records](/img/9f/b9111da8b2d9f8e79d82847aec906c.png)
[netding Cup 2018] Fakebook records

数据库:MySQL基础+CRUD基本操作

Meeting OA my meeting

Question 141 of Li Kou: circular linked list

分页插件--PageHelper

Opencv camera calibration and distortion correction

【C语言】经典的递归问题
随机推荐
Mysql database complex operations: Database Constraints, query / connect table operations
Thousands of tiles' tilt model browsing speeds up, saying goodbye to the embarrassment of jumping out one by one
Apple TV HD with the first generation Siri remote is listed as obsolete
第1章 开发第一个restful应用
Paging plug-in -- PageHelper
查看 Anaconda 创建环境的位置
09_ Keyboard events
The attorney general and the director of the national security service of Ukraine were dismissed
Tree and binary tree (learning notes)
力扣141题:环形链表
Pytorch学习记录(二):张量
The basic operation of data tables in MySQL is very difficult. This experiment will take you through it from the beginning
MVC三层架构
证券公司哪家佣金最低?网上开户安全吗
简单的SQL优化
04-传统的Synchronized锁
Arthas quick start
Tencent cloud lightweight application server purchase method steps!
第1章 需求分析与ssm环境准备
Push to origin/master was rejected error resolution