当前位置:网站首页>MobileNetV1架构解析
MobileNetV1架构解析
2022-08-05 07:03:00 【别团等shy哥发育】
MobileNetV1架构解析
参考论文:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
1、简介
MobileNets基于一种流线型架构,使用深度可分离卷积构建轻量级深度神经网络。我们引入了两个简单的全局超参数,可以有效地在延迟和准确性之间进行权衡。这些超参数允许模型生成器根据问题的约束为其应用程序选择适当大小的模型。我们在资源和准确性权衡方面进行了大量实验,与其他流行的ImageNet分类模型相比,我们表现出了强大的性能。然后,我们展示了MobileNet在广泛的应用和用例中的有效性,包括目标检测、精细分类、人脸属性和大规模地理定位。
MobileNet模型可应用于各种识别任务,以实现高效的设备智能。

2、Model Architecture
2.1 Depthwise Separable Convolution
MobileNet模型基于深度可分离卷积,这是一种分解卷积的形式,将标准卷积分解为深度卷积和1*1的点卷积。对于MobileNet,深度卷积将单个滤波器应用于每个输入通道,然后,逐点卷积应用1*1卷积将输出与深度卷积相结合。
标准卷积在一个步骤中将输入滤波并组合成一组新的输出。深度可分离卷积将标准的卷积层分解为两层来做:
- 首先是各个通道单独做卷积运算,称之为Depthwise Convolution
- 然后用一个
1*1的标准卷积层进行各个通道间的合并,称之为Pointwise Convolution
论文中原图如下所示:

论文中将标准卷积(a)分级为深度卷积(b)和1*1逐点卷积的图如下:

2.2 Network Structure
MobileNet结构建立在深度可分离卷积上,第一层是全卷积。MobileNet架构如下表所示,除过最后的全连接层,所有层后面都是BatchNorm和ReLU非线性激活函数,最后的全连接层没有非线性,并馈送到softmax层进行分类。将深度卷积和点卷积计算为单独的层,MobileNet有28层。

3、传统卷积与深度可分离卷积图解
3.1 传统卷积

- 卷积核channel=输入特征矩阵channel
- 输出特征矩阵channel-卷积核个数
3.2 Depthwise卷积

- 卷积核channel=1
- 输入特征矩阵channel=卷积核个数=输出特征矩阵channel
DW卷积中的每一个卷积核只会和输入特征矩阵的一个channel进行卷积计算,所以输出的特征矩阵就等于输入的特征矩阵。
3.3 Pointwise卷积

Pointwise卷积和普通的卷积一样,只不过使用了1*1卷积核。
3.3 Depthwise Separable Convolution(深度可分离卷积)
深度可分离卷积由Depthwise卷积和Pointwise卷积两部分组成

4、Width Multiplier: Thinner Models
虽然基本MobileNet架构已经很小且延迟很低,但很多时候,特定用例或应用程序可能需要更小更快的模型。为了构造这些较小且计算成本较低的模型,我们引入了一个非常简单的参数α,称为宽度乘数。宽度倍增器α的作用是在每一层均匀地薄化网络。对于给定的层和宽度乘数α,输入通道数M变为αM,输出通道数N变为αN。
5、Resolution Multiplier:Reduced Representation
降低神经网络计算成本的第二个超参数是分辨率乘数ρ。我们将其应用于输入图像,然后通过相同的乘法器减少每个层的内部表示。
表3显示了当架构收缩方法顺序应用于层时,层的计算和参数数量。第一行显示了全卷积层的Mult加法和参数,输入特征图的大小为14×14×512,核K的大小为3×3×512×512。

6、模型计算量与精度之间的权衡
首先,我们展示了具有深度可分离卷积的MobileNet与使用全卷积构建的模型相比的结果。在表4中,我们看到,与全卷积相比,使用深度可分离卷积在ImageNet上仅降低了1%的准确度,并在多个加法和参数上节省了大量成本。

表5显示,在类似的计算和参数数量下,使MobileNets变薄比使其变浅要好3%。

表6显示了使用宽度乘数α缩小MobileNet架构的精度、计算和尺寸权衡。精度平稳下降,直到在α=0.25时架构变得太小。

表7显示了通过使用降低的输入分辨率训练MobileNet,不同分辨率乘法器的精度、计算和大小权衡。精度在整个分辨率范围内平稳下降。

7、模型搭建(Tensorflow2.0)
这里是手动搭建的,你也可以直接使用迁移学习相关的API。
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import DepthwiseConv2D,BatchNormalization,ReLU,Conv2D
from tensorflow.keras.layers import GlobalAvgPool2D,Dense
# 定义可分离卷积块结构
def depthwise_conv_block(inputs,pointwise_conv_filters,strides=(1,1)):
x=DepthwiseConv2D((3,3),padding='same',strides=strides,use_bias=False)(inputs)
x=BatchNormalization()(x)
x=ReLU(6.0)(x)
x=Conv2D(pointwise_conv_filters,kernel_size=(1,1),padding='same',use_bias=False)(x)
x=BatchNormalization()(x)
x=ReLU(6.0)(x)
return x
# 定义MobileNet第一层普通卷积结构
def conv_block(inputs,filters,kernel_size=(3,3),strides=(1,1)):
x=Conv2D(filters,kernel_size=kernel_size,strides=strides,padding='same',
use_bias=False)(inputs)
x=BatchNormalization()(x)
x=ReLU(6.0)(x)
return x
def mobilenet_v1(inputs,classes):
# 第一层普通卷积
#[32,32,3]=>[16,16,32]
x=conv_block(inputs,32,strides=(2,2))
# [16,16,32]=>[16,16,64]
x=depthwise_conv_block(x,64)
# [16,16,64]=>[8,8,128]
x=depthwise_conv_block(x,128,strides=(2,2))
# [8,8,128]=>[8,8,128]
x=depthwise_conv_block(x,128)
# [8,8,128]=>[4,4,256]
x=depthwise_conv_block(x,256,strides=(2,2))
# [4,4,256]=>[4,4,256]
x=depthwise_conv_block(x,256)
# [4,4,256]=>[2,2,512]
x=depthwise_conv_block(x,512,strides=(2,2))
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[1,1,1024]
x=depthwise_conv_block(x,1024,strides=(2,2))
# [1,1,1024]=>[1,1,1024]
x=depthwise_conv_block(x,1024)
# [1,1,1024]=>[1024,]
x=GlobalAvgPool2D()(x)
# [1024,]=>[classes,]
x=Dense(classes,activation='softmax')(x)
return x
这里顺便在CIFAR10数据集上面测试一下。
INPUT_WIDTH=32
INPUT_HEIGHT=32
N_CHANNELS=3
N_CLASSES=10
batch_size=128
epochs=10
inputs=tf.keras.Input(shape=(INPUT_WIDTH,INPUT_HEIGHT,N_CHANNELS))
outputs=mobilenet_v1(inputs,N_CLASSES)
model=tf.keras.Model(inputs=inputs,outputs=outputs)
model.summary()

# 数据准备
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.cifar10.load_data()
# 归一化
x_train,x_test=x_train/255.0,x_test/255.0
# 转为独热编码
y_train=tf.keras.utils.to_categorical(y_train,N_CLASSES)
y_test=tf.keras.utils.to_categorical(y_test,N_CLASSES)
# 模型编译
adam=tf.keras.optimizers.Adam(1e-4)
model.compile(optimizers=adam,loss='categorical_crossentropy',
metrics=['accuracy'])
# 模型训练
history = model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test, y_test),
validation_freq=1
)

accuracy = history.history["accuracy"]
val_accuracy = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
fig,ax = plt.subplots(1,2,figsize=(12,6)) # figsize=(width, height)
ax[0].plot(epochs, accuracy, "bo", label="Training accuracy")
ax[0].plot(epochs, val_accuracy, "b", label="Validation accuracy")
ax[0].set_title("Training and validation accuracy")
ax[0].legend()
ax[1].plot(epochs, loss, "bo", label="Training loss")
ax[1].plot(epochs, val_loss, "b", label="Validation loss")
ax[1].set_title("Training and validation loss")
ax[1].legend()

References
Howard A G , Zhu M , Chen B , et al. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications[J]. 2017.
边栏推荐
- 风控特征的优化分箱,看看这样教科书的操作
- 对数据类型而言运算符无效。运算符为 add,类型为 text。
- 工作3年,回想刚入门和现在的今昔对比,笑谈一下自己的测试生涯
- MySQL: JDBC programming
- Day9 of Hegong Daqiong team vision team training - camera calibration
- TCP sticky packet unpacking problem + solution
- Mysql master-slave delay reasons and solutions
- 【工具配置篇】VSCode 常用使用总结
- 配合屏幕录像专家,又小又清晰!
- 【JVM调优】Xms和Xmx为什么要保持一致
猜你喜欢
随机推荐
图片地址转为base64
Technical Analysis Patterns (11) How to Trade Head and Shoulders Patterns
对数据类型而言运算符无效。运算符为 add,类型为 text。
2022熔化焊接与热切割操作证考试题及模拟考试
Falsely bamboo brother today and found a localization of API to use tools
binary search tree problem
Bluetooth gap protocol
TRACE32——Break
任务流调度工具AirFlow,,220804,,
In the anaconda Promat interface, import torch is passed, and the error is reported in the jupyter notebook (only provide ideas and understanding!)
Shiny02---Shiny异常解决
Promise (3) async/await
环网冗余式CAN/光纤转换器 CAN总线转光纤转换器中继集线器hub光端机
TRACE32——C源码关联1
[instancetype type Objective-C]
【win7】NtWaitForKeyedEvent
Japan Sanitary Equipment Industry Association: Japan's warm water shower toilet seat shipments reached 100 million sets
开启防火墙iptable规则后,系统网络变慢
[Tool Configuration] Summary of Common Uses of VSCode
女生做软件测试会不会成为一个趋势?









