当前位置:网站首页>动手学深度学习_NiN
动手学深度学习_NiN
2022-08-04 21:00:00 【CV小Rookie】
LeNet 、AlexNet 和 VGG 都有一个共同的设计模式:通过一系列的卷积层与汇聚层来提取空间结构特征;然后通过全连接层对特征的表征进行处理。 AlexNet 和 VGG 对 LeNet 的改进主要在于如何扩大和加深这两个模块。
然而,如果使用了全连接层,可能会完全放弃表征的空间结构。 网络中的网络(NiN)提供了一个非常简单的解决方案:在每个像素的通道上分别使用多层感知机(其实就是加两层 1 x 1 的卷积,因为前面说过,1 x 1 的卷积相当于参数共享的 MLP)
通过图解可以看到,NiN 网络就是由 nin_block 组成,一个 nin_block 由一个卷积层 + 两个
1 x 1卷积组成:
最后的输出取消使用 MLP ,而是使用一个全局的 Pooling 将特征图的高和宽变为1,最后使用 Flatten 展平,得到输出。
def nin_block(in_channels, out_channels, kernel_size, strides, padding):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, strides, padding),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=1), nn.ReLU())
class NiN(nn.Module):
def __init__(self):
super(NiN, self).__init__()
self.model =nn.Sequential(
nin_block(1, 96, kernel_size=11, strides=4, padding=0),
nn.MaxPool2d(3, stride=2),
nin_block(96, 256, kernel_size=5, strides=1, padding=2),
nn.MaxPool2d(3, stride=2),
nin_block(256, 384, kernel_size=3, strides=1, padding=1),
nn.MaxPool2d(3, stride=2),
nn.Dropout(0.5),
# 标签类别数是10
nin_block(384, 10, kernel_size=3, strides=1, padding=1),
nn.AdaptiveAvgPool2d((1, 1)),
# 将四维的输出转成二维的输出,其形状为(批量大小,10)
nn.Flatten()
)
def forward(self,x):
x = self.model(x)
return x
每一层的输出的 size :
Sequential output shape: torch.Size([1, 96, 54, 54]) MaxPool2d output shape: torch.Size([1, 96, 26, 26]) Sequential output shape: torch.Size([1, 256, 26, 26]) MaxPool2d output shape: torch.Size([1, 256, 12, 12]) Sequential output shape: torch.Size([1, 384, 12, 12]) MaxPool2d output shape: torch.Size([1, 384, 5, 5]) Dropout output shape: torch.Size([1, 384, 5, 5]) Sequential output shape: torch.Size([1, 10, 5, 5]) AdaptiveAvgPool2d output shape: torch.Size([1, 10, 1, 1]) Flatten output shape: torch.Size([1, 10])
边栏推荐
猜你喜欢
随机推荐
Using Baidu EasyDL to realize forest fire early warning and identification
使用百度EasyDL实现森林火灾预警识别
Zero-knowledge proof notes - private transaction, pederson, interval proof, proof of ownership
嵌入式分享合集28
js数据类型、节流/防抖、点击事件委派优化、过渡动画
Oreo domain name authorization verification system v1.0.6 public open source version website source code
How to make good use of builder mode
暴雨中的人
adb控制常用命令
Red5搭建直播平台
深度解析:为什么跨链桥又双叒出事了?
How to carry out AI business diagnosis and quickly identify growth points for cost reduction and efficiency improvement?
jekyll 在博客添加流程图
[TypeScript] In-depth study of TypeScript enumeration
如何最简单、通俗地理解爬虫的Scrapy框架?
拒绝服务攻击DDoS介绍与防范
构建Buildroot根文件系统(I.MX6ULL)
Cryptography Series: PEM and PKCS7, PKCS8, PKCS12
Uniapp微信雪糕刺客单页小程序源码
Common methods of js's new Function()