当前位置:网站首页>3.MNIST数据集分类
3.MNIST数据集分类
2022-07-07 23:11:00 【booze-J】
一、MNIST数据集及Softmax
1.MNIST数据集
大多数示例使用手写数字的MNIST数据集。该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。
每一张图片包含28*28个像素,在MNIST训练数据集中是一个形状为[60000,28,28]的张量,我们首先需要把数据集转成[60000,784],然后才能放到网络中训练。第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。一般我们还需要把图片中的数据归一化0~1之间。
MNIST数据集的标签是介于0-9的数字,我们要把标签转化为"one-hotvectors"。一个one-hot向量除了一位数字是1外,其余维度数字都是0,比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]),标签3将表示为([0,0,0,1,0,0,0,0,0,0])。
因此,MNIST数据集的标签是一个[60000,10]的数字矩阵。
28*28=784,每张图片有784个像素点,对应着784个神经元。最后输出10个神经元对应着10个数字。
2.Softmax
Softmax作用就是把神经网络的输出转化为概率值。
我们知道MNIST的结果是0-9,我们模型可能推测出一张图片的数字9的概率是80%,是数字8的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。
二、MNIST数据集分类
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
# 创建模型 输入784个神经元,输出10个神经元
model = Sequential([
# 定义输出是10 输入是784,设置偏置为1,添加softmax激活函数
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax"),
])
# 定义优化器
sgd = SGD(lr=0.2)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=sgd,
loss="mse",
metrics=['accuracy']
)
# 训练模型
model.fit(x_train,y_train,batch_size=32,epochs=10)
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
最终运行结果:
注意
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax")
这里用到了softmax激活函数。- 这里我们使用的
fit
方法进行的模型训练,之前的线性回归和非线性回归的模型训练方式和这不同。
代码:
model.compile(
optimizer=sgd,
loss="mse",
metrics=['accuracy']
)
中添加metrics=['accuracy']
,可以在训练过程中计算准确率。
边栏推荐
- Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
- LeetCode刷题
- QT establish signal slots between different classes and transfer parameters
- New library online | cnopendata China Star Hotel data
- Image data preprocessing
- 【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出
- 基于微信小程序开发的我最在行的小游戏
- 新库上线 | CnOpenData中国星级酒店数据
- Thinkphp内核工单系统源码商业开源版 多用户+多客服+短信+邮件通知
- 接口测试进阶接口脚本使用—apipost(预/后执行脚本)
猜你喜欢
Service Mesh介绍,Istio概述
Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
v-for遍历元素样式失效
Qt不同类之间建立信号槽,并传递参数
Binder core API
Interface test advanced interface script use - apipost (pre / post execution script)
ReentrantLock 公平锁源码 第0篇
Jouer sonar
An error is reported during the process of setting up ADG. Rman-03009 ora-03113
C# 泛型及性能比较
随机推荐
The method of server defense against DDoS, Hangzhou advanced anti DDoS IP section 103.219.39 x
Kubernetes Static Pod (静态Pod)
Service Mesh的基本模式
jemter分布式
An error is reported during the process of setting up ADG. Rman-03009 ora-03113
NVIDIA Jetson测试安装yolox过程记录
Summary of the third course of weidongshan
Which securities company has a low, safe and reliable account opening commission
Reentrantlock fair lock source code Chapter 0
4.交叉熵
A network composed of three convolution layers completes the image classification task of cifar10 data set
Deep dive kotlin synergy (XXII): flow treatment
什么是负载均衡?DNS如何实现负载均衡?
Leetcode brush questions
英雄联盟胜负预测--简易肯德基上校
How to learn a new technology (programming language)
How to insert highlighted code blocks in WPS and word
Installation and configuration of sublime Text3
深潜Kotlin协程(二十二):Flow的处理
接口测试要测试什么?