LSTM applied to MNIST dataset classification (compared with CNN)
2022-07-05 10:49:00 【Don't wait for shy brother to develop】
Applying LSTM to MNIST dataset classification
1、Overview
An LSTM network is a sequence model and is generally better suited to sequence problems. Here it is used to classify images of handwritten digits, which amounts to treating each image as a sequence.
Each image in the MNIST dataset is 28 × 28 pixels. We can treat each row as one step of the input sequence: an image has 28 rows, so the sequence length is 28, and each row holds 28 pixel values, so each step receives 28 input values.
We can then compare the code and results of the LSTM with those of a CNN.
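The row-as-time-step view described above can be sketched with NumPy. This is only an illustration: the random `batch` array is a hypothetical stand-in for real MNIST images, and the variable names are chosen to match the ones used later in the post.

```python
import numpy as np

# Hypothetical stand-in for a small batch of MNIST images: 4 random 28x28 arrays.
rng = np.random.default_rng(0)
batch = rng.random((4, 28, 28))

# A recurrent layer reads this as (number of samples, sequence length, input size):
# each image row is one time step, and the 28 pixels in that row are its features.
num_images, time_steps, input_size = batch.shape

# The sequence for the first image, one row per time step.
first_image = batch[0]
rows_as_steps = [first_image[t] for t in range(time_steps)]

print(num_images, time_steps, input_size)
print(len(rows_as_steps), rows_as_steps[0].shape)
```

No reshaping is needed: the array layout already matches what the recurrent layer expects.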
2、LSTM implementation
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import LSTM,Dropout
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
2.1 Loading the dataset
# Load the dataset
mnist = tf.keras.datasets.mnist
# load_data() returns the data already split into a training set and a test set
# Training images x_train have shape (60000, 28, 28)
# Training labels y_train have shape (60000,)
# Test images x_test have shape (10000, 28, 28)
# Test labels y_test have shape (10000,)
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale the pixel values of both sets to [0, 1], which helps the model train faster
x_train, x_test = x_train / 255.0, x_test / 255.0
# Convert the training and test labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# Input size: each row has 28 pixels
input_size = 28
# Sequence length: 28 rows in total
time_steps = 28
# Number of memory blocks (units) in the hidden layer
cell_size = 50
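To make the two preprocessing steps above concrete, here is a minimal NumPy sketch of pixel scaling and one-hot encoding. The helper `one_hot` is a hypothetical illustration written for this post, not the implementation of `to_categorical`, and the sample labels and pixel values are made up.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Return a (len(labels), num_classes) array with a single 1 per row."""
    labels = np.asarray(labels, dtype=int)
    encoded = np.zeros((labels.size, num_classes), dtype=np.float32)
    # Put a 1 in the column given by each label.
    encoded[np.arange(labels.size), labels] = 1.0
    return encoded

# Made-up digit labels.
labels = [5, 0, 4]
encoded = one_hot(labels)
print(encoded)

# Made-up 8-bit pixel values scaled from [0, 255] to [0, 1].
pixels = np.array([0, 128, 255], dtype=np.uint8)
scaled = pixels / 255.0
print(scaled)
```

One-hot labels are what the `categorical_crossentropy` loss used later expects.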
2.2 Creating the model
# Create the model
# The input to a recurrent neural network must be 3-D,
# with shape (number of samples, sequence length, input size)
# The loaded MNIST data already has this shape
# Note that input_shape does not include the number of samples
model = Sequential([
    LSTM(units=cell_size,input_shape=(time_steps,input_size),return_sequences=True),
    Dropout(0.2),
    LSTM(cell_size),
    Dropout(0.2),
    # The 50 values output by the 50 memory blocks are fully connected to the 10 output neurons
    Dense(10,activation=tf.keras.activations.softmax)
])
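The size of the model above can be estimated by hand. The sketch below assumes the standard LSTM formulation with four gates and one bias vector per gate, which is my understanding of the Keras defaults. The helper names are hypothetical, and the numbers should be checked against `model.summary()` rather than taken as authoritative.

```python
def lstm_params(input_dim, units):
    # Four gates (input, forget, cell candidate, output), each with
    # input weights, recurrent weights, and a bias vector.
    return 4 * (units * (input_dim + units) + units)

def dense_params(input_dim, units):
    # Weight matrix plus one bias per output unit.
    return input_dim * units + units

input_size, cell_size, num_classes = 28, 50, 10

first_lstm = lstm_params(input_size, cell_size)     # sees 28 pixels per step
second_lstm = lstm_params(cell_size, cell_size)     # sees the 50 outputs of the first LSTM
output_layer = dense_params(cell_size, num_classes)
total = first_lstm + second_lstm + output_layer     # Dropout layers add no parameters

print(first_lstm, second_lstm, output_layer, total)  # 15800 20200 510 36510
```

The first LSTM uses `return_sequences=True` so that the second LSTM receives all 28 steps and not just the final one. This changes the output shape but not the parameter count.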
2.3 Defining the optimizer
# Recent Keras versions use learning_rate; the older lr argument is deprecated
adam = Adam(learning_rate=1e-3)
2.4 Compiling the model
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
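As a rough illustration of what the `categorical_crossentropy` loss measures when paired with a softmax output, here is a small NumPy sketch. The functions are hypothetical re-implementations for explanation only, the clipping constant `eps` is an assumption, and the three-class logits are made up.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    # Clip to avoid log(0); eps is an assumed value, not taken from the post.
    y_pred = np.clip(y_pred, eps, 1.0 - eps)
    # With one-hot targets this reduces to -log(probability of the true class).
    return -np.sum(y_true * np.log(y_pred), axis=-1)

# Made-up example with 3 classes; the true class is class 0.
logits = np.array([[2.0, 1.0, 0.1]])
y_true = np.array([[1.0, 0.0, 0.0]])

probs = softmax(logits)
loss = categorical_crossentropy(y_true, probs)
print(probs, loss)
```

The loss shrinks toward zero as the predicted probability of the correct class approaches 1.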
2.5 Training the model
history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test,y_test))
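For a sense of scale, the number of gradient updates implied by the `fit` call above can be worked out directly. This is plain arithmetic on the values in the call and assumes the final partial batch is kept. Note also that `validation_data` is the test set here, so the "validation" curves plotted later are measured on test data.

```python
import math

num_train = 60000   # MNIST training images, from the shapes listed earlier
batch_size = 64
epochs = 10

# Number of batches per epoch, counting the final partial batch.
steps_per_epoch = math.ceil(num_train / batch_size)
# Size of that final batch.
last_batch = num_train - (steps_per_epoch - 1) * batch_size
total_updates = steps_per_epoch * epochs

print(steps_per_epoch, last_batch, total_updates)  # 938 32 9380
```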
2.6 Print model summary
model.summary()
2.7 Plotting the accuracy and loss curves
loss=history.history['loss']
val_loss=history.history['val_loss']
accuracy=history.history['accuracy']
val_accuracy=history.history['val_accuracy']
# Plot the loss curves
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# Plot the accuracy curves
plt.plot(accuracy, label='Training accuracy')
plt.plot(val_accuracy, label='Validation accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()
The LSTM also achieves good results on MNIST recognition, though not as good as those of the convolutional neural network.
3、CNN implementation
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,Dropout,Convolution2D,MaxPooling2D,Flatten
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
# Load the dataset
mnist = tf.keras.datasets.mnist
# load_data() returns the data already split into a training set and a test set
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Note: in TensorFlow, convolution layers expect 4-D input
# The 4 dimensions are (number of samples, image height, image width, number of channels)
# So the data is reshaped to 4-D here; grayscale images have 1 channel, color images have 3
x_train = x_train.reshape(-1,28,28,1)/255.0
x_test = x_test.reshape(-1,28,28,1)/255.0
# Convert the training and test labels to one-hot encoding
y_train = tf.keras.utils.to_categorical(y_train,num_classes=10)
y_test = tf.keras.utils.to_categorical(y_test,num_classes=10)
# Define a sequential model
model = Sequential()
# First convolution layer
# input_shape: shape of the input data
# filters: number of filters, 32, producing 32 feature maps
# kernel_size: convolution window size, 5*5
# strides: stride, 1
# padding: padding mode, same or valid
# activation: activation function
model.add(Convolution2D(
    input_shape = (28,28,1),
    filters = 32,
    kernel_size = 5,
    strides = 1,
    padding = 'same',
    activation = 'relu'
))
# First pooling layer
# pool_size: pooling window size, 2*2
# strides: stride, 2
# padding: padding mode, same or valid
model.add(MaxPooling2D(pool_size = 2,strides = 2,padding = 'same'))
# Second convolution layer
# filters: number of filters, 64, producing 64 feature maps
# kernel_size: convolution window size, 5*5
# strides: stride, 1
# padding: padding mode, same or valid
# activation: activation function
model.add(Convolution2D(64,5,strides=1,padding='same',activation='relu'))
# Second pooling layer
# pool_size: pooling window size, 2*2
# strides: stride, 2
# padding: padding mode, same or valid
model.add(MaxPooling2D(pool_size = 2,strides = 2,padding = 'same'))
# Flatten the output of the second pooling layer
# Equivalent to reshaping (64,7,7,64) -> (64,7*7*64), where the leading 64 is the batch size
model.add(Flatten())
# First fully connected layer
model.add(Dense(1024,activation = 'relu'))
# Dropout
model.add(Dropout(0.5))
# Second fully connected layer
model.add(Dense(10,activation='softmax'))
# Define the optimizer (recent Keras versions use learning_rate; the older lr argument is deprecated)
adam = Adam(learning_rate=1e-4)
# Set the optimizer, the loss function, and the accuracy metric reported during training
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])
# Train the model
history=model.fit(x_train,y_train,batch_size=64,epochs=10,validation_data=(x_test, y_test))
# Save the model
model.save('mnist_cnn.h5')
# Print the model summary
model.summary()
loss=history.history['loss']
val_loss=history.history['val_loss']
accuracy=history.history['accuracy']
val_accuracy=history.history['val_accuracy']
# draw loss curve
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()
# draw acc curve
plt.plot(accuracy, label='Training accuracy')
plt.plot(val_accuracy, label='Validation accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.show()
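The layer shapes and parameter counts of the CNN above can be derived by hand, which is a useful cross-check against `model.summary()`. The sketch below assumes that `padding='same'` gives an output size of ceil(input / stride). The helper names are hypothetical, and the totals are a hand calculation, not output copied from the original run.

```python
import math

def same_padding_out(size, stride):
    # With padding='same', the spatial output size is ceil(input / stride).
    return math.ceil(size / stride)

def conv_params(kernel, in_channels, out_channels):
    # One kernel*kernel*in_channels filter per output channel, plus a bias each.
    return kernel * kernel * in_channels * out_channels + out_channels

def dense_params(in_dim, units):
    return in_dim * units + units

size = 28
size = same_padding_out(size, 1)   # first convolution:  28
size = same_padding_out(size, 2)   # first pooling:      14
size = same_padding_out(size, 1)   # second convolution: 14
size = same_padding_out(size, 2)   # second pooling:     7
flat = size * size * 64            # Flatten: 7*7*64 = 3136

params = [
    conv_params(5, 1, 32),         # 832
    conv_params(5, 32, 64),        # 51264
    dense_params(flat, 1024),      # 3212288
    dense_params(1024, 10),        # 10250
]
total = sum(params)
print(size, flat, params, total)   # total is 3274634
```

Almost all of the parameters sit in the first fully connected layer, and the CNN as a whole is far larger than the two-layer LSTM used earlier.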
[Figures: model summary, loss curve, accuracy curve]
Judging from these results, a CNN is better suited than an LSTM to classifying the MNIST dataset.