[Neural Network] (22) ConvMixer code reproduction and network analysis, with complete TensorFlow code
2022-06-10 01:31:00 【Vertical sir】
Hello everyone. Today I'd like to share how to build the ConvMixer convolutional neural network model with TensorFlow.
I came across this network by chance. It is a very simple model to implement, yet it achieves strong accuracy, surpassing the Vision Transformer model; it has a real less-is-more quality.
Paper: https://openreview.net/forum?id=TVHS5Y4dNvM
1. Introduction
In recent years, Transformer models have been steadily challenging the dominance of convolutional neural networks in computer vision, with the emergence of Vision Transformer (ViT), which can go toe-to-toe with CNNs, and the epoch-making Swin Transformer. The authors of this paper focus on the ViT model and pose a question: does ViT's performance come from its powerful Transformer structure, or from using patches as the input representation?
In the paper, the authors show that the patch embedding has the larger impact on ViT's accuracy, and they propose a very simple model, ConvMixer, similar in spirit to ViT and MLP-Mixer. The model takes patches directly as input, separates the mixing of the spatial and channel dimensions, and maintains the same resolution throughout the network.
Although ConvMixer's design is simple, experiments show that at similar parameter counts and dataset sizes it outperforms ViT, MLP-Mixer and some of their variants, as well as classic vision models such as ResNet.
2. Model building
First, import the packages we need:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

2.1 Patch Embedding
The main function of the patch embedding layer is to split the original input image of size (h, w) into patches. The size of each patch is specified as (patch_size, patch_size), so each image is divided into (h//patch_size, w//patch_size) patches.
It is implemented with a single convolution layer whose kernel_size and stride both equal patch_size.

The code is as follows:
# ---------------------------------------------- #
#(1) patch embedding layer
'''out_channel: number of output channels, patch_size: height and width of each image patch'''
# ---------------------------------------------- #
def patchembed(inputs, out_channel, patch_size):
    # a standard convolution with kernel size patch_size*patch_size and stride patch_size splits the image into patches
    x = layers.Conv2D(filters = out_channel,    # number of output channels
                      kernel_size = patch_size, # convolution kernel size
                      strides = patch_size,     # convolution stride
                      padding = 'same',         # equivalent to 'valid' here, since 224 divides evenly by 7
                      use_bias = False)(inputs)
    # GELU activation, then BN normalization (the order used in the ConvMixer paper)
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    return x
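
As a quick sanity check (my own addition, not part of the original network code), feeding a dummy 224*224 image through patchembed should yield a 32*32 grid of 1536-dimensional patch embeddings:

# throwaway shape check; assumes the imports at the top of this post
dummy = tf.zeros([1, 224, 224, 3])                       # one 224*224 RGB image
out = patchembed(dummy, out_channel=1536, patch_size=7)
print(out.shape)                                         # (1, 32, 32, 1536), since 224//7 = 32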
2.2 Feature extraction layer

The feature extraction layer here consists of three parts: depthwise convolution (depthwise conv), pointwise convolution (pointwise conv), and a residual connection (shortcut), as shown in the paper's ConvMixer Layer figure.
For the principle of depthwise separable convolution, see my blog post: https://blog.csdn.net/dgvv4/article/details/123476899

First, the input feature map passes through a depthwise convolution, which extracts information along the height and width directions; it has one kernel per channel of the input feature map, so the input and output feature maps have identical shapes. A residual connection then adds the input to the output. Finally, a 1*1 pointwise convolution fuses information along the channel direction; its number of kernels equals the number of output channels.
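
To see why the depthwise convolution is so cheap, the snippet below (my own illustration, not from the original post) compares its parameter count with that of a full 9*9 convolution at 1536 channels:

# depthwise conv: one 9*9 filter per channel; full conv: 1536 filters of size 9*9*1536
channels, k = 1536, 9
x = tf.zeros([1, 32, 32, channels])
dw = layers.DepthwiseConv2D(kernel_size=k, padding='same', use_bias=False)
full = layers.Conv2D(filters=channels, kernel_size=k, padding='same', use_bias=False)
print(dw(x).shape, full(x).shape)  # both (1, 32, 32, 1536): same output shape
print(dw.count_params())           # 9*9*1536 = 124,416 parameters
print(full.count_params())         # 9*9*1536*1536 = 191,102,976 parameters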
The code is as follows:
# ---------------------------------------------- #
#(2) single feature extraction module
'''out_channel: number of output channels of the pointwise convolution, kernel_size: kernel size of the depthwise convolution'''
# ---------------------------------------------- #
def layer(inputs, out_channel, kernel_size):
    # 9*9 depthwise convolution for spatial feature extraction
    x = layers.DepthwiseConv2D(kernel_size = kernel_size, # convolution kernel size
                               strides = 1,               # no downsampling
                               padding = 'same',          # size unchanged before and after convolution
                               use_bias = False)(inputs)
    # GELU activation, BN normalization
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    # residual connection
    x = x + inputs
    # 1*1 pointwise convolution
    x = layers.Conv2D(filters = out_channel, # number of output channels
                      kernel_size = 1,       # 1*1 convolution
                      strides = 1)(x)
    # GELU activation, BN normalization
    x = layers.Activation('gelu')(x)
    x = layers.BatchNormalization()(x)
    return x
# ---------------------------------------------- #
#(3) stack multiple feature extraction modules
'''depth: number of stacked modules'''
# ---------------------------------------------- #
def blocks(x, depth, out_channel, kernel_size):
    # stack depth feature extraction modules
    for _ in range(depth):
        x = layer(x, out_channel, kernel_size)
    return x
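
Consistent with the paper's design of keeping the same resolution throughout the network, stacking these modules never changes the feature map size. A quick check of my own (depth=2 just to keep it fast):

feat = tf.zeros([1, 32, 32, 1536])
out = blocks(feat, depth=2, out_channel=1536, kernel_size=9)
print(out.shape)  # (1, 32, 32, 1536): same resolution and channel count as the input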
2.3 Backbone network

The ConvMixer network structure is very simple. The image first goes through the patch embedding layer to be split into patches, then through 20 feature extraction modules, and finally a fully connected layer produces the output.
Here we build the ConvMixer-1536/20 model, where 1536 is the number of output channels of the patch embedding layer and 20 is the number of stacked feature extraction modules. Each patch has size 7*7 (patch_size), and the depthwise convolution inside the feature extraction module uses a 9*9 kernel.

The code is as follows:
# ---------------------------------------------- #
#(4) backbone network
'''input_shape: size of the input image (without the batch dimension), num_classes: number of classes'''
# ---------------------------------------------- #
def convmixer(input_shape, num_classes):
    # construct the input layer [b,224,224,3]
    inputs = keras.Input(shape=input_shape)
    # patch embedding layer [b,224//7,224//7,1536]
    x = patchembed(inputs, out_channel=1536, patch_size=7)
    # 20 feature extraction layers [b,224//7,224//7,1536]
    x = blocks(x, depth=20, out_channel=1536, kernel_size=9)
    # global average pooling [b,1536]
    x = layers.GlobalAveragePooling2D()(x)
    # fully connected classification head [b,num_classes]
    outputs = layers.Dense(num_classes)(x)
    # build the model
    model = keras.Model(inputs, outputs)
    return model
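
The same building blocks can also instantiate the paper's other variants by exposing the hyperparameters. Below is a sketch of a generalized builder (my own refactor, not from the original post), used here for ConvMixer-768/32, which the paper pairs with patch size 7 and a 7*7 depthwise kernel as far as I recall:

# hypothetical generalization of convmixer(): dim/depth/patch_size/kernel_size exposed
def convmixer_custom(input_shape, num_classes, dim, depth, patch_size, kernel_size):
    inputs = keras.Input(shape=input_shape)
    x = patchembed(inputs, out_channel=dim, patch_size=patch_size)
    x = blocks(x, depth=depth, out_channel=dim, kernel_size=kernel_size)
    x = layers.GlobalAveragePooling2D()(x)
    outputs = layers.Dense(num_classes)(x)
    return keras.Model(inputs, outputs)

# e.g. ConvMixer-768/32
model_768_32 = convmixer_custom([224,224,3], 1000, dim=768, depth=32, patch_size=7, kernel_size=7)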
2.4 View the network architecture

Taking 1000-class classification as an example, let's view the network structure.
# ---------------------------------------------- #
#(5) view the network structure
# ---------------------------------------------- #
if __name__ == '__main__':
    # build the model
    model = convmixer(input_shape=[224,224,3], num_classes=1000)
    # view the network structure
    model.summary()

The tail of the printed network structure is as follows:
conv2d_20 (Conv2D)                                 (None, 32, 32, 1536)  2360832   ['tf.__operators__.add_19[0][0]']
activation_40 (Activation)                         (None, 32, 32, 1536)  0         ['conv2d_20[0][0]']
batch_normalization_40 (BatchNormalization)        (None, 32, 32, 1536)  6144      ['activation_40[0][0]']
global_average_pooling2d (GlobalAveragePooling2D)  (None, 1536)          0         ['batch_normalization_40[0][0]']
dense (Dense)                                      (None, 1000)          1537000   ['global_average_pooling2d[0][0]']
==================================================================================================
Total params: 51,719,656
Trainable params: 51,593,704
Non-trainable params: 125,952
__________________________________________________________________________________________________
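
Finally, to actually train the model, note that the Dense head outputs raw logits (there is no softmax), so the loss should be configured with from_logits=True. A minimal training sketch of my own; the optimizer and learning rate are placeholder choices, not the recipe from the paper:

model = convmixer(input_shape=[224,224,3], num_classes=1000)
model.compile(optimizer = keras.optimizers.Adam(learning_rate=1e-3),
              # from_logits=True because the Dense head has no softmax
              loss = keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics = ['accuracy'])
# model.fit(train_ds, validation_data=val_ds, epochs=...)  # supply your own tf.data pipeline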