当前位置:网站首页>TensorFlow2的Conv1D, Conv2D,Conv3D机器对应的MaxPooling详解
TensorFlow2的Conv1D, Conv2D,Conv3D机器对应的MaxPooling详解
2022-06-13 01:37:00 【星空下0516】
TensorFlow2对Conv1D, Conv2D, Conv3D都有详细的解释,针对公有的
卷积操作
机器学习中卷积,顾名思义有卷的动作:
s ( t ) = ∫ x ( a ) w ( t − a ) d a s(t) = \int x(a)w(t-a)da s(t)=∫x(a)w(t−a)da
这里x一般叫做输入数据,而w叫做卷积核。从公式可以看出,w在对x进行卷积,卷积操作是对x进行倒序相乘,如图:
注意图片中卷积核的红色角标,即上述卷积公式的操作。
主要参数进行说明一下:
- filters: 卷积核数目,卷积计算时折射使用的空间维度;
- kernel_size: 卷积核大小,要求是一个Tensor,具有[filter_height, filter_width, in_channels, out_channels]这样的shape, 具体含义是[卷积核高度,卷积核宽度,图像通道数,卷积核个数],要求类型与参数input相同。有一个地方需要注意的是,第三维in_channels就是input的第四维;
- strides: 步进大小,卷积时在图像每一维的步长。这是一个一维的向量,第一维和第四维默认是1,而第三维和第四维分别是水平和垂直方向滑行的步进长度;
- padding: 补全方式,string类型的量,只能是“SAME”和“VALID”其中之一,这个值决定了不同的卷积方式;
- activation: 激活函数,一般使用ReLU作为激活函数。
经过卷积后输出数据的大小,即conv大小:
N = ( W − F + 2 P ) / S + 1 ( P a d d i n g = “ S A M E ” ) N = ( W − F ) / S + 1 ( P a d d i n g = “ V A L I D ” ) N=(W-F+2P)/S+1(Padding=“SAME”)\\ N=(W-F)/S+1(Padding=“VALID”)\\ N=(W−F+2P)/S+1(Padding=“SAME”)N=(W−F)/S+1(Padding=“VALID”)
这里各个参数表示的意义如下:
- N: 输出数据的大小, 即经过一次卷积后输出的尺寸
- W: 输入图片的尺寸, 即input对应的值
- F: 卷积核大小,即kernel_size对应的值
- P: Padding的像素数,SAME模式下一般情况下取1,VALID模式下取0,如果不给该值则P=0
- S: 步长S,即strides对应的值,默认值为1
那么我们就可以从上面的公式得到N,即输出的数据为2×2的尺寸。如果我们选择SAME补齐,会结果就会得到3×3的尺寸,而选择VALID,则输出2×2。
举例说明一下:
Conv1D
import tensorflow as tf
input = tf.Variable(tf.random.normal([1, 3, 1])) # 随机输入一个3×1的数组
conv = tf.keras.layers.Conv1D(1, 2)(input) # 使用1个2x1大小的卷积核进行卷积,这里省略了步进(默认是1,未给出),padding(这里没有使用padding)等。
print(conv)
输入尺寸是:3x1,根据上述公式,输出大小为2x1。
import tensorflow as tf
input = tf.Variable(tf.random.normal([1, 3, 1])) # 随机输入一个3×1的数组
conv = tf.keras.layers.Conv1D(1, 2, padding="valid")(input) # 使用1个2×1大小的卷积核进行卷积,这里省略了步进(默认是1,未给出),padding(valid模式)等。
print(conv)
输入尺寸是:3x1,根据上述公式,输出大小为2x1,Padding采用valid模式。
import tensorflow as tf
input = tf.Variable(tf.random.normal([1, 3, 3, 1])) # 随机输入一个3×1的数组
conv = tf.keras.layers.Conv1D(1, 2, padding="same")(input) # 使用1个2×1大小的卷积核进行卷积,这里省略了步进(默认是1,未给出),padding(same模式)等。
print(conv)
输入尺寸是:3x1,根据上述公式,输出大小为3x1,Padding采用same模式。
Conv2D
import tensorflow as tf
input = tf.Variable(tf.random.normal([1, 3, 3, 1])) # 随机输入一个3×3的数组
conv = tf.keras.layers.Conv2D(1, 2)(input) # 使用1个2x2大小的卷积核进行卷积,这里省略了步进(默认是1,未给出),padding(这里没有使用padding)等。
print(conv)
输入尺寸是:3x3,根据上述公式,输出大小为2x2。其他模式根据公式可以推出。
Conv3D
import tensorflow as tf
input = tf.Variable(tf.random.normal([1, 3, 3, 3, 1])) # 随机输入一个3x3x3的数组
conv = tf.keras.layers.Conv3D(1, 2)(input) # 使用1个2x2x2大小的卷积核进行卷积,这里省略了步进(默认是1,未给出),padding(这里没有使用padding)等。
print(conv)
输入尺寸是:3x3x3,根据上述公式,输出大小为2x2x2。其他模式根据公式可以推出。
池化操作
池化操作是为了防止过拟合,主要有两种池化方式:平均池化和最大池化。
- 平均池化:将池化窗口的数值求平均,使用这个平均值作为该窗口的值;
- 最大池化:将池化窗口的数值求最大值,使用这个最大值作为该窗口的值。
池化操作示意图:
重要参数设置:
- pool_size: 池化窗口的大小,默认一般是[2, 2]
- strides: 和卷积类似,表示窗口每个维度上滑动的步长,默认一般是[2, 2]
- padding: 和卷积类似,可以采用"SAME"和"VALID"两种模式,返回一个Tensor,shape依然是[batch, height, width, channels]类型。
池化后图像尺寸变化公式:
N = W − P s S + 1 N= \frac{W-P_s}{S}+1 N=SW−Ps+1
这里:
- N: 表示池化前输出图片的尺寸大小;
- W: 表示池化后输入图片的尺寸大小;
- S: 表示滑动步长;
- Ps: 表示池化尺寸。
边栏推荐
- Leetcode question brushing 07 double pointer
- Network communication tcp/ip
- Run Presto under docker to access redis and Bi presentation
- How to solve the problems when using TV focusable to package APK in uni app
- [Stanford Jiwang cs144 project] lab1: streamreassembler
- Three paradigms of database
- 关于tkinter.Canvas 不显示图片的问题
- About tkinter Canvas does not display pictures
- Method of cleaning C disk
- Set and array conversion, list, array
猜你喜欢

ES6解构赋值

Run Presto under docker to access redis and Bi presentation

Leetcode question brushing 02 linked list operation
![[projet cs144 de Stanford Computing Network] lab1: Stream reassembler](/img/7b/fad18b68a6ee30d1dec4dad6273b98.png)
[projet cs144 de Stanford Computing Network] lab1: Stream reassembler

September 3, 2021 visual notes

MySQL related summary
![[andoid][step pit]cts 11_ Testbootclasspathandsystemserverclasspath at the beginning of R3_ Analysis of nonduplicateclasses fail](/img/b5/7ea603775dc0448368d209de037a43.png)
[andoid][step pit]cts 11_ Testbootclasspathandsystemserverclasspath at the beginning of R3_ Analysis of nonduplicateclasses fail
![[leetcode] valid phone number Bash](/img/f8/cecb74f9d8f7c589e62e3a9a874f82.jpg)
[leetcode] valid phone number Bash

C language implementation of the classic eight queens problem

受众群体应该选择观察模式还是定位模式?
随机推荐
工作与生活
Answer to matrix theory of Nanjing University of Aeronautics and Astronautics
Differences among bio, NiO and AIO
谷歌的受众群体是如何发挥作用的?
Loss calculation in pytorch
项目实训(十七)---个人工作总结
About retrieving ignored files in cornerstone
Understanding of the detach() function of pytorch
Rasa dialogue robot helpdesk (III)
leetcode. 151. flip the words in the string
Startup, connection and stop of MySQL service
About the proposed signature file migration to industry standard format pkcs12
ng-tv-focusable
Realization of flip animation
Stone from another mountain: Web3 investment territory of a16z
Leetcode question brushing 03 stack
【MathType】利用MathType输出LaTex样式的公式
On February 26, 2022, the latest news of national oil price adjustment today
MySQL - use field alias after where
np. Understanding of axis in concatenate