当前位置:网站首页>MobileNetV1架构解析
MobileNetV1架构解析
2022-08-05 07:03:00 【别团等shy哥发育】
MobileNetV1架构解析
参考论文:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications
1、简介
MobileNets基于一种流线型架构,使用深度可分离卷积构建轻量级深度神经网络。我们引入了两个简单的全局超参数,可以有效地在延迟和准确性之间进行权衡。这些超参数允许模型生成器根据问题的约束为其应用程序选择适当大小的模型。我们在资源和准确性权衡方面进行了大量实验,与其他流行的ImageNet分类模型相比,我们表现出了强大的性能。然后,我们展示了MobileNet在广泛的应用和用例中的有效性,包括目标检测、精细分类、人脸属性和大规模地理定位。
MobileNet模型可应用于各种识别任务,以实现高效的设备智能。
2、Model Architecture
2.1 Depthwise Separable Convolution
MobileNet模型基于深度可分离卷积,这是一种分解卷积的形式,将标准卷积分解为深度卷积和1*1
的点卷积。对于MobileNet,深度卷积将单个滤波器应用于每个输入通道,然后,逐点卷积应用1*1
卷积将输出与深度卷积相结合。
标准卷积在一个步骤中将输入滤波并组合成一组新的输出。深度可分离卷积将标准的卷积层分解为两层来做:
- 首先是各个通道单独做卷积运算,称之为Depthwise Convolution
- 然后用一个
1*1
的标准卷积层进行各个通道间的合并,称之为Pointwise Convolution
论文中原图如下所示:
论文中将标准卷积(a)分级为深度卷积(b)和1*1
逐点卷积的图如下:
2.2 Network Structure
MobileNet结构建立在深度可分离卷积上,第一层是全卷积。MobileNet架构如下表所示,除过最后的全连接层,所有层后面都是BatchNorm和ReLU非线性激活函数,最后的全连接层没有非线性,并馈送到softmax层进行分类。将深度卷积和点卷积计算为单独的层,MobileNet有28层。
3、传统卷积与深度可分离卷积图解
3.1 传统卷积
- 卷积核channel=输入特征矩阵channel
- 输出特征矩阵channel-卷积核个数
3.2 Depthwise卷积
- 卷积核channel=1
- 输入特征矩阵channel=卷积核个数=输出特征矩阵channel
DW卷积中的每一个卷积核只会和输入特征矩阵的一个channel进行卷积计算,所以输出的特征矩阵就等于输入的特征矩阵。
3.3 Pointwise卷积
Pointwise卷积和普通的卷积一样,只不过使用了1*1
卷积核。
3.3 Depthwise Separable Convolution(深度可分离卷积)
深度可分离卷积由Depthwise卷积和Pointwise卷积两部分组成
4、Width Multiplier: Thinner Models
虽然基本MobileNet架构已经很小且延迟很低,但很多时候,特定用例或应用程序可能需要更小更快的模型。为了构造这些较小且计算成本较低的模型,我们引入了一个非常简单的参数α,称为宽度乘数。宽度倍增器α的作用是在每一层均匀地薄化网络。对于给定的层和宽度乘数α,输入通道数M变为αM,输出通道数N变为αN。
5、Resolution Multiplier:Reduced Representation
降低神经网络计算成本的第二个超参数是分辨率乘数ρ。我们将其应用于输入图像,然后通过相同的乘法器减少每个层的内部表示。
表3显示了当架构收缩方法顺序应用于层时,层的计算和参数数量。第一行显示了全卷积层的Mult加法和参数,输入特征图的大小为14×14×512,核K的大小为3×3×512×512。
6、模型计算量与精度之间的权衡
首先,我们展示了具有深度可分离卷积的MobileNet与使用全卷积构建的模型相比的结果。在表4中,我们看到,与全卷积相比,使用深度可分离卷积在ImageNet上仅降低了1%的准确度,并在多个加法和参数上节省了大量成本。
表5显示,在类似的计算和参数数量下,使MobileNets变薄比使其变浅要好3%。
表6显示了使用宽度乘数α缩小MobileNet架构的精度、计算和尺寸权衡。精度平稳下降,直到在α=0.25时架构变得太小。
表7显示了通过使用降低的输入分辨率训练MobileNet,不同分辨率乘法器的精度、计算和大小权衡。精度在整个分辨率范围内平稳下降。
7、模型搭建(Tensorflow2.0)
这里是手动搭建的,你也可以直接使用迁移学习相关的API。
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.layers import DepthwiseConv2D,BatchNormalization,ReLU,Conv2D
from tensorflow.keras.layers import GlobalAvgPool2D,Dense
# 定义可分离卷积块结构
def depthwise_conv_block(inputs,pointwise_conv_filters,strides=(1,1)):
x=DepthwiseConv2D((3,3),padding='same',strides=strides,use_bias=False)(inputs)
x=BatchNormalization()(x)
x=ReLU(6.0)(x)
x=Conv2D(pointwise_conv_filters,kernel_size=(1,1),padding='same',use_bias=False)(x)
x=BatchNormalization()(x)
x=ReLU(6.0)(x)
return x
# 定义MobileNet第一层普通卷积结构
def conv_block(inputs,filters,kernel_size=(3,3),strides=(1,1)):
x=Conv2D(filters,kernel_size=kernel_size,strides=strides,padding='same',
use_bias=False)(inputs)
x=BatchNormalization()(x)
x=ReLU(6.0)(x)
return x
def mobilenet_v1(inputs,classes):
# 第一层普通卷积
#[32,32,3]=>[16,16,32]
x=conv_block(inputs,32,strides=(2,2))
# [16,16,32]=>[16,16,64]
x=depthwise_conv_block(x,64)
# [16,16,64]=>[8,8,128]
x=depthwise_conv_block(x,128,strides=(2,2))
# [8,8,128]=>[8,8,128]
x=depthwise_conv_block(x,128)
# [8,8,128]=>[4,4,256]
x=depthwise_conv_block(x,256,strides=(2,2))
# [4,4,256]=>[4,4,256]
x=depthwise_conv_block(x,256)
# [4,4,256]=>[2,2,512]
x=depthwise_conv_block(x,512,strides=(2,2))
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[2,2,512]
x=depthwise_conv_block(x,512)
# [2,2,512]=>[1,1,1024]
x=depthwise_conv_block(x,1024,strides=(2,2))
# [1,1,1024]=>[1,1,1024]
x=depthwise_conv_block(x,1024)
# [1,1,1024]=>[1024,]
x=GlobalAvgPool2D()(x)
# [1024,]=>[classes,]
x=Dense(classes,activation='softmax')(x)
return x
这里顺便在CIFAR10数据集上面测试一下。
INPUT_WIDTH=32
INPUT_HEIGHT=32
N_CHANNELS=3
N_CLASSES=10
batch_size=128
epochs=10
inputs=tf.keras.Input(shape=(INPUT_WIDTH,INPUT_HEIGHT,N_CHANNELS))
outputs=mobilenet_v1(inputs,N_CLASSES)
model=tf.keras.Model(inputs=inputs,outputs=outputs)
model.summary()
# 数据准备
(x_train,y_train),(x_test,y_test)=tf.keras.datasets.cifar10.load_data()
# 归一化
x_train,x_test=x_train/255.0,x_test/255.0
# 转为独热编码
y_train=tf.keras.utils.to_categorical(y_train,N_CLASSES)
y_test=tf.keras.utils.to_categorical(y_test,N_CLASSES)
# 模型编译
adam=tf.keras.optimizers.Adam(1e-4)
model.compile(optimizers=adam,loss='categorical_crossentropy',
metrics=['accuracy'])
# 模型训练
history = model.fit(
x_train,
y_train,
batch_size=batch_size,
epochs=epochs,
validation_data=(x_test, y_test),
validation_freq=1
)
accuracy = history.history["accuracy"]
val_accuracy = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]
epochs = range(1, len(accuracy) + 1)
fig,ax = plt.subplots(1,2,figsize=(12,6)) # figsize=(width, height)
ax[0].plot(epochs, accuracy, "bo", label="Training accuracy")
ax[0].plot(epochs, val_accuracy, "b", label="Validation accuracy")
ax[0].set_title("Training and validation accuracy")
ax[0].legend()
ax[1].plot(epochs, loss, "bo", label="Training loss")
ax[1].plot(epochs, val_loss, "b", label="Validation loss")
ax[1].set_title("Training and validation loss")
ax[1].legend()
References
Howard A G , Zhu M , Chen B , et al. MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications[J]. 2017.
边栏推荐
- After the firewall iptable rule is enabled, the system network becomes slow
- 一天学会从抓包到接口测试,通过智慧物业项目深度解析
- 【win7】NtWaitForKeyedEvent
- 文本特征化方法总结
- Day9 of Hegong Daqiong team vision team training - camera calibration
- TRACE32——C源码关联1
- C# FileSystemWatcher
- [Shanghai] Hiring .Net Senior Software Engineer & BI Data Warehouse Engineer (Urgent)
- Shared memory + inotify mechanism to achieve multi-process low-latency data sharing
- 不太会讲爱,其实已经偷偷幸福很久啦----我们的故事
猜你喜欢
随机推荐
re正则表达式
DeFi 前景展望:概览主流 DeFi 协议二季度进展
女生做软件测试会不会成为一个趋势?
线程池的使用(结合Future/Callable使用)
LaTeX Notes
Flink Learning 11: Flink Program Parallelism
TRACE32——C源码关联1
Redis
Falsely bamboo brother today and found a localization of API to use tools
Discourse 清理存储空间的方法
标准C语言15
Technical Analysis Patterns (11) How to Trade Head and Shoulders Patterns
Summary of Text Characterization Methods
访问被拒绝:“microsoft.web.ui.webcontrols”的解决办法
props 后面的数据流是什么?
RK3568 environment installation
typescript64-映射类型
TCP的粘包拆包问题+解决方案
MySQL:order by排序查询,group by分组查询
字节面试流程及面试题无私奉献,吐血整理