10. CNN applied to handwritten digit recognition
2022-07-08 01:02:00 【booze-J】
The code was run in Jupyter Notebook. The code blocks in this article follow the notebook's cell divisions, so to run them, paste each block into its own Jupyter Notebook cell in order. The code comments are fairly brief.
1. Import third-party libraries
import numpy as np
# Import everything from tensorflow.keras so that layers and optimizer come from the same Keras package
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Convolution2D, MaxPooling2D, Flatten
from tensorflow.keras.optimizers import Adam
2. Load and preprocess the data
# Load the data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n", x_train.shape)
# (60000,): the labels are plain integers 0-9, not yet one-hot encoded (done below)
print("y_shape:\n", y_train.shape)
# (60000, 28, 28) -> (60000, 28, 28, 1) = (number of images, image height, image width, number of channels)
# Passing -1 to reshape() lets NumPy infer that dimension automatically
# Dividing by 255.0 scales pixel values from 0-255 into [0, 1]
# Normalization matters: inputs on a small, consistent scale typically make training converge faster and more stably
x_train = x_train.reshape(-1, 28, 28, 1) / 255.0
x_test = x_test.reshape(-1, 28, 28, 1) / 255.0
# Convert the labels to one-hot format
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
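You can check what the two preprocessing steps do without downloading MNIST. The sketch below uses a small random `uint8` array as a stand-in for `x_train` (an assumption, not the real data). It builds the one-hot matrix with `np.eye`, which for integer labels 0-9 should give the same matrix as `to_categorical(..., num_classes=10)`.

```python
import numpy as np

# Stand-in for MNIST (assumption): 5 random 28x28 grayscale images with pixel values 0..255
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
labels = np.array([3, 0, 9, 1, 3])

# (5, 28, 28) -> (5, 28, 28, 1): -1 lets NumPy infer the number of images,
# the trailing 1 is the single grayscale channel; dividing by 255.0 maps pixels into [0, 1]
x = images.reshape(-1, 28, 28, 1) / 255.0

# One-hot encoding: row i of the 10x10 identity matrix represents class i
y = np.eye(10, dtype="float32")[labels]

print(x.shape, x.dtype, x.min() >= 0.0, x.max() <= 1.0)
print(y[0])  # label 3 -> 1.0 in position 3, 0.0 elsewhere
```

Calling `y.argmax(axis=1)` recovers the original integer labels. The same relationship is used later when predicted probabilities are compared with the one-hot targets.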
3. Build, train and evaluate the model
# Define a sequential model
model = Sequential()
# First convolution layer. Only the first layer needs the size of the input image; later layers infer it
# input_shape: shape of one input image (height, width, channels)
# filters: number of convolution kernels (filters)
# kernel_size: size of the convolution window
# strides: step size
# padding: padding mode, "same" or "valid"
# activation: activation function
model.add(Convolution2D(
    input_shape=(28, 28, 1),
    filters=32,
    kernel_size=5,
    strides=1,
    padding="same",
    activation="relu"
))
# First pooling layer
model.add(MaxPooling2D(
    pool_size=2,
    strides=2,
    padding="same"
))
# Second convolution layer
model.add(Convolution2D(filters=64, kernel_size=5, strides=1, padding="same", activation="relu"))
# Second pooling layer
model.add(MaxPooling2D(pool_size=2, strides=2, padding="same"))
# Flatten the output of the second pooling layer into one dimension
model.add(Flatten())
# First fully connected layer
model.add(Dense(units=1024, activation="relu"))
# Dropout: during training, randomly sets 50% of this layer's inputs to 0 to reduce overfitting (inactive at inference)
model.add(Dropout(0.5))
# Second fully connected layer (output layer): softmax over the 10 digit classes
model.add(Dense(units=10, activation="softmax"))
# Define the optimizer with a learning rate of 1e-4 (newer Keras versions use learning_rate instead of lr)
adam = Adam(learning_rate=1e-4)
# Set the optimizer and loss function, and track accuracy during training
model.compile(optimizer=adam, loss="categorical_crossentropy", metrics=["accuracy"])
# Train the model
model.fit(x_train, y_train, batch_size=64, epochs=10)
# Evaluate the model on the test set
loss, accuracy = model.evaluate(x_test, y_test)
print("test loss:", loss)
print("test accuracy:", accuracy)
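The two numbers printed by `model.evaluate` can be reproduced by hand from the softmax outputs. The following NumPy sketch uses made-up logits for three samples (hypothetical values, not output of the model above). It assumes the standard definitions: softmax, categorical cross-entropy averaged over the samples, and accuracy as the fraction of samples whose largest probability falls on the true class. It is an illustration, not Keras' own implementation.

```python
import numpy as np

def softmax(logits):
    # Subtract the row maximum for numerical stability; the result is unchanged
    z = logits - logits.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def categorical_crossentropy(y_true, y_prob, eps=1e-7):
    # Mean over samples of -sum_k y_k * log(p_k); clipping avoids log(0)
    p = np.clip(y_prob, eps, 1.0 - eps)
    return float(np.mean(-np.sum(y_true * np.log(p), axis=1)))

def accuracy(y_true, y_prob):
    # A prediction counts as correct when its arg-max class equals the one-hot label
    return float(np.mean(y_prob.argmax(axis=1) == y_true.argmax(axis=1)))

# Hypothetical logits for 3 samples and 10 classes
logits = np.zeros((3, 10))
logits[0, 7] = 5.0   # confident and correct
logits[1, 2] = 2.0   # correct but less confident
logits[2, 4] = 3.0   # wrong: the true class is 1
y_true = np.eye(10)[[7, 2, 1]]

probs = softmax(logits)
print("loss:", categorical_crossentropy(y_true, probs))
print("accuracy:", accuracy(y_true, probs))
```

Two of the three samples are classified correctly, so accuracy is 2/3. The loss also rewards confidence: sample 1 is counted as correct for accuracy but still adds more loss than sample 0.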
The points that need attention are also explained in the code comments.
Note
- Only the first layer of the network needs the size of the input image (input_shape); later layers can omit it.
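Only the first layer needs `input_shape` because every later layer derives its input shape from the previous layer's output. The pure-Python sketch below traces that propagation for the model above and counts trainable parameters. It assumes the output-size rules from the Keras convolution and pooling docs (`same`: ceil(n / stride); `valid`: ceil((n - kernel + 1) / stride)). It also assumes the usual weight counts: kernel weights plus one bias per filter or unit. The helpers `out_size` and `trace` are hypothetical, not Keras functions.

```python
import math

def out_size(n, kernel, stride, padding):
    # Spatial output size of a convolution or pooling layer along one axis
    if padding == "same":
        return math.ceil(n / stride)
    return math.ceil((n - kernel + 1) / stride)  # "valid"

def trace(input_shape):
    h, w, c = input_shape
    total = 0
    rows = []
    # Conv(32, 5x5) -> Pool(2) -> Conv(64, 5x5) -> Pool(2), all with padding="same"
    for filters in (32, 64):
        params = 5 * 5 * c * filters + filters          # kernel weights + biases
        h, w, c = out_size(h, 5, 1, "same"), out_size(w, 5, 1, "same"), filters
        rows.append((f"Conv2D({filters})", (h, w, c), params))
        h, w = out_size(h, 2, 2, "same"), out_size(w, 2, 2, "same")
        rows.append(("MaxPooling2D", (h, w, c), 0))
        total += params
    units_in = h * w * c                                  # Flatten
    rows.append(("Flatten", (units_in,), 0))
    for units in (1024, 10):                              # Dropout adds no parameters
        params = units_in * units + units
        rows.append((f"Dense({units})", (units,), params))
        total += params
        units_in = units
    return rows, total

rows, total = trace((28, 28, 1))
for name, shape, params in rows:
    print(f"{name:<14}{str(shape):<16}{params:>10,}")
print("Total params:", f"{total:,}")
```

The sketch gives 3,274,634 parameters in total. About 98% of them sit in the Dense(1024) layer, which takes the 7 x 7 x 64 = 3136 flattened features. Calling `model.summary()` on the Keras model should report the same shapes and counts.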