当前位置:网站首页>来亲自手搭一个ResNet18网络
来亲自手搭一个ResNet18网络
2022-08-02 17:33:00 【GIS与Climate】

在何大佬的文章中,提出了下面两种残差块:

左边的称为building block,右边的称为bottleneck building block。
左边的输入和输出都是64个channel的,四四方方的,像个建筑物; 右边的就好像通过了一个瓶颈一样,输入残差块的网络通道数会先从256变成64,然后最终再升到256,其中降维和升维使用的是1x1的卷积,可以减少参数量;
代码
Building Block,就是之前写的(换了激活函数):
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(in_channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.relu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.relu(x + residual)
return out
BottleNeck的代码稍微改下就行:
class BottleNeck(nn.Module):
def __init__(self,in_channels):
super(BottleNeck, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(in_channels,64,kernel_size=1,stride=1,padding=0),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(64,64,kernel_size=3,stride=1,padding=1),
nn.BatchNorm2d(64),
nn.ReLU(),
nn.Conv2d(63,in_channels,kernel_size=1,stride=1,padding=0),
nn.BatchNorm2d(in_channels),
)
self.shortcut = nn.Sequential()
def forward(self,x):
shortcut = self.shortcut(x)
residual = self.main(x)
out = nn.ReLU(shortcut + residual )
return out
几点经验
在网络中,如果层数比较多的时候尽可能使用容器来写(比如上面的Sequential),这样子看起来更加的清晰; 在残差块的最后加上输入之后记得要加上激活函数; 要有 积木思想,就是尽可能的把网络中的结构搭建为可复用的 块,就比如上面的残差块; BottleNeck用在较深的网络层中可以减少参数量;
ResNet18
在何大佬的文章中提出了几种不同的残差网络,主要是网络层的不同,最少的为18层:

我们根据上面的信息复现一下ResNet18。先分析其结构:
原文中用的图像输入是3*224*224,先通过一个7*7*64的卷积,但是步长设置为2,使得图像的大小缩小了一半; 在con2_x的刚开始,通过一个最大值池化,步长设置为2,使得图像又缩小了一半; 然后是con2_x、con3_x、con4_x、con5_x一共8个残差块; 按照作者说的,在con3_1、con4_1、con5_1都进行了2倍的下采样; 最后一层先经过一个自适应平均池化层,然后一个全连接层映射到输出;
那么根据上面的过程写代码即可,但是写之前有几点需要注意:
原文章中说了每个卷积层后面跟上一个批归一化层(BN层); 特征图的尺寸减半的时候,特征图的数量要增加一倍; 原文说的是直接用的步长为2的卷积层进行下采样;
整体代码:
class ResidualBlock(nn.Module):
def __init__(self, in_channels):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(in_channels)
self.relu = nn.ReLU()
self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(in_channels)
def forward(self, x):
residual = self.conv1(x)
residual = self.bn1(residual)
residual = self.relu(residual)
residual = self.conv2(residual)
residual = self.bn2(residual)
out = self.relu(x + residual)
return out
class ResNet18(nn.Module):
def __init__(self,in_channels,resblock,outputs=1000):
super(ResNet18, self).__init__()
self.block1 = nn.Sequential(
nn.Conv2d(in_channels,64,kernel_size=7,stride=2,padding=3),
nn.BatchNorm2d(64),
nn.ReLU(),
)
self.block2 = nn.Sequential(
nn.MaxPool2d(kernel_size=3,stride=2,padding=1),
resblock(in_channels=64),
resblock(in_channels=64)
)
self.block3 = nn.Sequential(
nn.Conv2d(64,128,kernel_size=3,stride=(2,2),padding=1),
resblock(in_channels=128),
resblock(in_channels=128),
)
self.block4 = nn.Sequential(
nn.Conv2d(128,256,kernel_size=3,stride=(2,2),padding=1),
resblock(in_channels=256),
resblock(in_channels=256),
)
self.block5 = nn.Sequential(
nn.Conv2d(256,512,kernel_size=3,stride=(2,2),padding=1),
resblock(in_channels=512),
resblock(in_channels=512),
)
self.block6 = nn.AdaptiveAvgPool2d(output_size=(1,1))
self.fc = nn.Linear(in_features=512,out_features=outputs)
def forward(self, x):
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = x.reshape(x.shape[0],-1)
x = self.fc(x)
return x
为什么叫做ResNet18?
如果打印出来看下(用之前说的torchsummary),可以发现其中带有可学习参数的层数一共是18层,所以叫做ResNet18(除去那些BN层、激活函数层等)。
小总结与注意项
上面的代码我是严格按照参考【1】的论文进行复现的,可能跟网上的有些不一样,比如跟参考【3】的就不太一样; 网上各种ResNet18的复现,但是也有一些跟论文不太一样的地方,比如参考【4】在最初的卷积层之后就没有加激活层; 还有其他的一些网上的教程也是各不相同,但是都是大同小异,所以在使用的时候要自己仔细斟酌; 上面的复现跟Pytorch官方提供的基本一致,但是参数量有些不同,后面就没有细细比较了; 使用别人的代码的时候一定要先读懂了原理,不要无脑直接套用,血的教训; 酌情根据自己的需求修改其中可以修改的模块(比如激活函数,卷积核的大小等)。
参考
【1】HE K, ZHANG X, REN S, et al. Deep Residual Learning for Image Recognition[C]//2016 IEEE Conference on Computer Vision and Pattern Recognition (CVPR).2016:770-778. 10.1109/CVPR.2016.90.
【2】https://blog.csdn.net/sazass/article/details/116864275
【3】https://towardsdev.com/implement-resnet-with-pytorch-a9fb40a77448
【4】https://blog.csdn.net/weixin_36979214/article/details/108879684?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522162374909216780265420718%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=162374909216780265420718&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~all~sobaiduend~default-1-108879684.first_rank_v2_pc_rank_v29&utm_term=pytorch+resnet18&spm=1018.2226.3001.4187
边栏推荐
猜你喜欢
0725-面试记录
研发运营一体化(DevOps)能力成熟度模型
什么是实时流引擎?
Navicat premium download and install 15 detailed tutorial
小程序毕设作品之微信体育馆预约小程序毕业设计成品(8)毕业设计论文模板
Real-time data warehouse architecture evolution and selection
土巴兔IPO五次折戟,互联网家装未解“中介”之痛
How Tencent architects explained: The principle of Redis high-performance communication (essential version)
Simulink脚本自动创建Autosar Parameter Port及Mapping
Security First: Tools You Need to Know to Implement DevSecOps Best Practices
随机推荐
Smart Contract Security - delegatecall (1)
redis总结_多级缓存
融云「 IM 进阶实战高手课」系列直播上线
golang源码分析(5):sync.Once
百问百答第49期:极客有约——国内可观测领域SaaS产品的发展前景
罗敏背后是抖音
0725-面试记录
阿波罗 planning代码-modules\planning\lattice\trajectory_generation\PiecewiseBrakingTrajectoryGenerator类详解
2021年下半年软件设计师上午真题
今年上半年,我国公路建设总体形势持续向好
边界访问的空间权限
脉脉上的相亲生意
golang源码分析(7):chan
MySQL常见函数
恒驰5真的没大卖
Flink Learning 9: Configure the idea to develop the flink-Scala program environment
SQL 正则解析手机号码提供商
发挥云网融合优势,天翼云为政企铺设数字化转型跑道
一些与开发者体验有关的话题
Gear 月度更新|6 月