Build a ResNet18 Network by Hand
2022-08-02 17:33:00 【GIS与Climate】
In Kaiming He's paper [1], the following two kinds of residual blocks are proposed:
The left one is called the building block, the right one the bottleneck building block.
The left block has 64 channels at both input and output, squarish like a building. The right one is like passing through a bottleneck: the channel count first drops from 256 to 64 and is finally raised back to 256, where the reduction and restoration are done with 1x1 convolutions, which cuts down the parameter count. A rough comparison of the savings is sketched below.
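A back-of-the-envelope count (my own sketch, not from the paper; bias and BN parameters ignored):

def conv_params(c_in, c_out, k):
    # weight count of one conv layer, biases ignored
    return c_in * c_out * k * k

# bottleneck: 1x1 (256 -> 64), 3x3 (64 -> 64), 1x1 (64 -> 256)
bottleneck = conv_params(256, 64, 1) + conv_params(64, 64, 3) + conv_params(64, 256, 1)
# two plain 3x3 convolutions kept at 256 channels
plain = 2 * conv_params(256, 256, 3)
print(bottleneck, plain)  # 69632 vs 1179648, roughly 17x fewer weights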
Code
The building block is the one written before (with the activation function swapped):
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        # main path: conv -> BN -> ReLU -> conv -> BN
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        # add the identity shortcut, then apply the final activation
        out = self.relu(x + residual)
        return out
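A quick shape check (my own snippet, not from the original post): the block should leave both the channel count and the spatial size unchanged.

block = ResidualBlock(64)
x = torch.randn(1, 64, 56, 56)
print(block(x).shape)  # torch.Size([1, 64, 56, 56])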
The BottleNeck only needs small changes:
class BottleNeck(nn.Module):
    def __init__(self, in_channels):
        super(BottleNeck, self).__init__()
        self.main = nn.Sequential(
            # 1x1 conv reduces the channel count (e.g. 256 -> 64)
            nn.Conv2d(in_channels, 64, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 1x1 conv restores the channel count (64 -> in_channels)
            nn.Conv2d(64, in_channels, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(in_channels),
        )
        self.shortcut = nn.Sequential()  # identity shortcut
        self.relu = nn.ReLU()

    def forward(self, x):
        shortcut = self.shortcut(x)
        residual = self.main(x)
        out = self.relu(shortcut + residual)
        return out
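The same kind of check (my own snippet) confirms a 256-channel input comes out with its shape intact:

bottleneck = BottleNeck(256)
x = torch.randn(1, 256, 56, 56)
print(bottleneck(x).shape)  # torch.Size([1, 256, 56, 56])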
A few takeaways
- When a network has many layers, use containers wherever possible (such as the Sequential above); the code reads much more clearly.
- After adding the input back at the end of a residual block, remember to apply the activation function.
- Think in building blocks: structure the network into reusable blocks wherever possible, like the residual blocks above.
- Using BottleNeck in the deeper layers of a network reduces the parameter count.
ResNet18
He's paper proposes several residual networks, differing mainly in the number of layers; the shallowest has 18:
Let's reproduce ResNet18 from the information above. First, an analysis of its structure:
- The paper uses 3*224*224 image input, first passed through a 7*7, 64-channel convolution with stride 2, which halves the spatial size.
- At the very start of conv2_x, a max pooling with stride 2 halves the size again.
- Then come conv2_x, conv3_x, conv4_x and conv5_x, eight residual blocks in total.
- As the authors state, conv3_1, conv4_1 and conv5_1 each downsample by a factor of 2.
- At the end, an adaptive average pooling layer is followed by a fully connected layer mapping to the output.

A quick sanity check of these sizes is sketched below.
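This sketch traces the spatial sizes with the usual output-size formula floor((n + 2p - k) / s) + 1 (the conv_out helper is my own, not from the original post):

def conv_out(size, kernel, stride, padding):
    # standard convolution/pooling output-size formula
    return (size + 2 * padding - kernel) // stride + 1

size = 224
size = conv_out(size, 7, 2, 3)      # 7x7/2 conv: 224 -> 112
size = conv_out(size, 3, 2, 1)      # 3x3/2 max pool: 112 -> 56
for _ in range(3):                  # conv3_1, conv4_1, conv5_1 each halve
    size = conv_out(size, 3, 2, 1)  # 56 -> 28 -> 14 -> 7
print(size)  # 7; adaptive average pooling then brings 7x7 down to 1x1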
The code follows directly from the process above, but a few points deserve attention before writing it:
- The paper states that every convolutional layer is followed by a batch normalization (BN) layer.
- Whenever the feature map size is halved, the number of feature maps is doubled.
- The paper performs downsampling directly with stride-2 convolutional layers; a sketch of such a block follows this list.
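For reference, this is one way a paper-style downsampling block could look, with the stride-2 3x3 convolution inside the block and a 1x1 stride-2 projection on the shortcut. This DownsampleBlock is my own illustration of that idea; the full code below uses a standalone stride-2 convolution in front of each stage instead.

class DownsampleBlock(nn.Module):
    # illustrative only: halves the spatial size and doubles the channels,
    # keeping the shortcut shape-compatible via a 1x1 stride-2 projection
    def __init__(self, in_channels):
        super(DownsampleBlock, self).__init__()
        out_channels = in_channels * 2
        self.main = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
        )
        self.shortcut = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
            nn.BatchNorm2d(out_channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.main(x) + self.shortcut(x))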
Full code:
class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(in_channels)

    def forward(self, x):
        residual = self.conv1(x)
        residual = self.bn1(residual)
        residual = self.relu(residual)
        residual = self.conv2(residual)
        residual = self.bn2(residual)
        out = self.relu(x + residual)
        return out
class ResNet18(nn.Module):
    def __init__(self, in_channels, resblock, outputs=1000):
        super(ResNet18, self).__init__()
        # conv1: 7x7/2 convolution, 224 -> 112
        self.block1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        # conv2_x: 3x3/2 max pool (112 -> 56) plus two residual blocks
        self.block2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            resblock(in_channels=64),
            resblock(in_channels=64)
        )
        # conv3_x: stride-2 conv halves the size (56 -> 28) and doubles the channels
        self.block3 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            resblock(in_channels=128),
            resblock(in_channels=128),
        )
        # conv4_x: 28 -> 14
        self.block4 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            resblock(in_channels=256),
            resblock(in_channels=256),
        )
        # conv5_x: 14 -> 7
        self.block5 = nn.Sequential(
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            resblock(in_channels=512),
            resblock(in_channels=512),
        )
        # global average pooling, then the fully connected output layer
        self.block6 = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc = nn.Linear(in_features=512, out_features=outputs)

    def forward(self, x):
        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = x.reshape(x.shape[0], -1)  # flatten to (batch, 512)
        x = self.fc(x)
        return x
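An end-to-end check (my own snippet): build the model and push a random ImageNet-sized batch through it.

model = ResNet18(3, ResidualBlock, outputs=1000)
x = torch.randn(2, 3, 224, 224)
print(model(x).shape)  # torch.Size([2, 1000])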
Why is it called ResNet18?
If you print the model out (using torchsummary, mentioned earlier), you can see 18 layers with learnable weights in total, which is where the name ResNet18 comes from (the BN layers, activation layers and so on are not counted).
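A minimal sketch, assuming the torchsummary package is installed (the device argument is needed when running on CPU):

from torchsummary import summary

model = ResNet18(3, ResidualBlock, outputs=1000)
summary(model, (3, 224, 224), device="cpu")  # lists each layer with its output shape and parameter count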
A short summary and caveats
- The code above is a reproduction strictly following the paper in reference [1]; it may differ from versions found online, for example from reference [3].
- There are many ResNet18 reproductions online, some of which deviate from the paper; for instance, reference [4] adds no activation layer after the initial convolution.
- Other tutorials differ from each other as well, though mostly in minor ways, so weigh them carefully before use.
- The reproduction above is essentially the same as the official PyTorch one, apart from some differences in parameter count, which I did not compare in detail.
- Make sure you understand the principle before using someone else's code; do not copy it blindly (a lesson learned the hard way).
- Adapt the modifiable parts (activation functions, kernel sizes, etc.) to your own needs.
References
[1] He K, Zhang X, Ren S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR). 2016: 770-778. doi:10.1109/CVPR.2016.90
[2] https://blog.csdn.net/sazass/article/details/116864275
[3] https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
[4] https://blog.csdn.net/weixin_36979214/article/details/108879684