3. MNIST dataset classification
2022-07-08 01:02:00 【booze-J】
I. The MNIST Dataset and Softmax
1. The MNIST dataset
Most examples use the handwritten-digit MNIST dataset. It contains 60,000 examples for training and 10,000 examples for testing.
Each image is 28*28 pixels, so the MNIST training set is a tensor of shape [60000,28,28]. We first need to reshape it to [60000,784] before it can be fed into the network: the first dimension indexes the images, and the second indexes the pixels within each image. We also usually normalize the pixel values to the range 0 to 1.
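A minimal NumPy sketch of this reshape-and-normalize step. To keep it runnable without downloading MNIST, it uses a small random uint8 array as a hypothetical stand-in for the real [60000,28,28] training set.

```python
import numpy as np

# Hypothetical stand-in for MNIST images: 5 grayscale images of 28*28 pixels,
# stored as uint8 values in 0..255 like the real dataset.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)

# Flatten each 28*28 image into a 784-long vector; -1 lets NumPy infer 784.
flat = images.reshape(images.shape[0], -1)

# Scale pixel values from 0..255 to the range 0..1.
normalized = flat / 255.0

print(flat.shape)  # (5, 784)
```

The same two operations (reshape with -1, then divide by 255.0) are what the preprocessing code in Part II applies to the full dataset.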
The MNIST labels are digits from 0 to 9, and we need to convert them into one-hot vectors. A one-hot vector has a 1 in exactly one dimension and 0 in all the others; for example, label 0 is represented as [1,0,0,0,0,0,0,0,0,0] and label 3 as [0,0,0,1,0,0,0,0,0,0].
The MNIST training labels therefore form a [60000,10] matrix.
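The one-hot conversion can be sketched in plain NumPy. This is an illustrative equivalent only; the article's own code uses Keras's to_categorical for this step. Indexing rows of an identity matrix by label gives one one-hot row per label, so 60,000 labels would give a [60000,10] matrix.

```python
import numpy as np

# Example integer labels (digits 0-9), like MNIST's y_train before conversion.
labels = np.array([0, 3, 9])

# Row i of the 10x10 identity matrix is the one-hot vector for digit i,
# so indexing with the labels picks one row per label.
one_hot = np.eye(10)[labels]

print(one_hot[1])     # [0. 0. 0. 1. 0. 0. 0. 0. 0. 0.]
print(one_hot.shape)  # (3, 10)
```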
Since 28*28=784, each image has 784 pixels, which correspond to 784 input neurons. The 10 output neurons correspond to the 10 digits.
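To make the 784-in, 10-out structure concrete, here is a NumPy sketch of the computation a single fully connected layer performs. The weights are hypothetical random values, not trained ones; the biases start at 1 to match the article's model. The layer holds 784*10 weights plus 10 biases, 7,850 trainable parameters in total.

```python
import numpy as np

rng = np.random.default_rng(0)

n_inputs, n_outputs = 784, 10
W = rng.normal(scale=0.01, size=(n_inputs, n_outputs))  # one weight per (pixel, digit) pair
b = np.ones(n_outputs)                                   # biases initialized to 1

x = rng.random((32, n_inputs))  # hypothetical batch of 32 flattened, normalized images
logits = x @ W + b               # one score per digit for each image

print(logits.shape)     # (32, 10)
print(W.size + b.size)  # 7850
```

These raw scores are what the softmax activation then turns into probabilities.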
2. Softmax
The softmax function converts the outputs of a neural network into probabilities.
The MNIST classes are the digits 0-9. For a given image, the model might infer an 80% probability that it is a 9, a 10% probability that it is an 8, and smaller probabilities for each of the other digits, with all probabilities summing to 1. This is a classic use case for a softmax regression model: softmax assigns probabilities to the different classes.
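Softmax maps a vector of scores z to probabilities via softmax(z)_i = exp(z_i) / sum_j exp(z_j). Below is a minimal NumPy sketch; subtracting the maximum score first is a common numerical-stability trick (an addition here, not something the article discusses) that leaves the result unchanged but keeps exp() from overflowing.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax: turns each row of scores into probabilities."""
    z = np.asarray(z, dtype=float)
    # Subtracting the row maximum leaves the output unchanged but avoids overflow in exp().
    shifted = z - z.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

scores = np.array([1.0, 2.0, 3.0])
probs = softmax(scores)
print(probs.sum())     # approximately 1.0 (up to floating-point rounding)
print(probs.argmax())  # 2: the largest score gets the largest probability
```

Because exp() is monotonic, the class with the highest raw score always receives the highest probability, which is why taking the argmax of the softmax output gives the predicted digit.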
II. MNIST Dataset Classification
The code runs in Jupyter Notebook, and the code blocks in this article follow the order of the notebook cells. To run the code, paste each block directly into Jupyter Notebook.
1. Import third-party libraries
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.optimizers import SGD
2. Load and preprocess the data
# Load the data
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n", x_train.shape)
# (60000,) The labels are not one-hot encoded yet; they are converted below
print("y_shape:\n", y_train.shape)
# (60000, 28, 28) -> (60000, 784): passing -1 to reshape() lets NumPy infer that dimension,
# and dividing by 255.0 normalizes the pixel values to [0, 1]
x_train = x_train.reshape(x_train.shape[0], -1) / 255.0
x_test = x_test.reshape(x_test.shape[0], -1) / 255.0
# Convert the labels to one-hot format
y_train = to_categorical(y_train, num_classes=10)
y_test = to_categorical(y_test, num_classes=10)
3. Train the model
# Create the model: 784 input neurons, 10 output neurons
model = Sequential([
    # 10 output units, 784 inputs, biases initialized to 1, softmax activation
    Dense(units=10, input_dim=784, bias_initializer='ones', activation="softmax"),
])
# Define the optimizer (newer Keras versions use learning_rate instead of the old lr argument)
sgd = SGD(learning_rate=0.2)
# Set the optimizer and loss function, and compute accuracy during training
model.compile(
    optimizer=sgd,
    loss="mse",
    metrics=['accuracy']
)
# Train the model
model.fit(x_train, y_train, batch_size=32, epochs=10)
# Evaluate the model
loss, accuracy = model.evaluate(x_test, y_test)
print("\ntest loss", loss)
print("accuracy:", accuracy)
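With batch_size=32, model.fit splits the 60,000 training images into mini-batches of 32, so each epoch performs ceil(60000 / 32) = 1875 weight updates (Keras also shuffles the training data each epoch by default). The sketch below only illustrates that batching arithmetic with a hypothetical helper, iterate_minibatches; it is a simplification, not how Keras implements fit internally.

```python
import math
import numpy as np

def iterate_minibatches(x, y, batch_size, rng):
    """Hypothetical helper: shuffle once per epoch, then yield consecutive batches.

    The last batch is smaller when len(x) is not a multiple of batch_size.
    """
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]

n_samples, batch_size = 60000, 32
print(math.ceil(n_samples / batch_size))  # 1875 steps per epoch

# Small stand-in data so the loop runs quickly: 100 samples -> batches of 32, 32, 32, 4.
rng = np.random.default_rng(0)
x = rng.random((100, 784))
y = np.eye(10)[rng.integers(0, 10, size=100)]
sizes = [len(xb) for xb, _ in iterate_minibatches(x, y, batch_size, rng)]
print(sizes)  # [32, 32, 32, 4]
```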
Notes
- The Dense layer uses the softmax activation function (activation="softmax"), so its 10 outputs are probabilities that sum to 1.
- The model is trained here with the fit method, which differs from how the earlier linear regression and nonlinear regression models were trained.
- Adding metrics=['accuracy'] to model.compile() makes Keras compute the accuracy during training:
model.compile(
    optimizer=sgd,
    loss="mse",
    metrics=['accuracy']
)
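With loss="mse", the loss is the mean squared difference between the softmax outputs and the one-hot labels. For one-hot labels, accuracy counts a prediction as correct when the index of the largest predicted probability matches the index of the 1 in the label, which mirrors Keras's categorical accuracy. A NumPy sketch of both quantities on a hypothetical batch of three predictions (3 classes instead of 10 to keep it short):

```python
import numpy as np

# Hypothetical softmax outputs for 3 images (each row sums to 1).
y_pred = np.array([
    [0.1, 0.8, 0.1],
    [0.6, 0.3, 0.1],
    [0.2, 0.2, 0.6],
])
# One-hot labels: the true classes are 1, 1 and 2.
y_true = np.array([
    [0.0, 1.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

# MSE: average squared error over the classes, then over the samples.
mse = np.mean((y_pred - y_true) ** 2)

# Accuracy: fraction of rows whose most probable class equals the true class.
accuracy = np.mean(y_pred.argmax(axis=1) == y_true.argmax(axis=1))

print(round(mse, 4))  # 0.1289
print(accuracy)       # about 0.667: 2 of the 3 predictions are correct
```

The second image shows the difference between the two measures: it adds a large squared error to the loss and also counts as a wrong prediction, while the other two contribute small errors and count as correct.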