当前位置:网站首页>3.MNIST数据集分类
3.MNIST数据集分类
2022-07-07 23:11:00 【booze-J】
一、MNIST数据集及Softmax
1.MNIST数据集
大多数示例使用手写数字的MNIST数据集。该数据集包含60,000个用于训练的示例和10,000个用于测试的示例。
每一张图片包含28*28个像素,在MNIST训练数据集中是一个形状为[60000,28,28]的张量,我们首先需要把数据集转成[60000,784],然后才能放到网络中训练。第一个维度数字用来索引图片,第二个维度数字用来索引每张图片中的像素点。一般我们还需要把图片中的数据归一化0~1之间。
MNIST数据集的标签是介于0-9的数字,我们要把标签转化为"one-hotvectors"。一个one-hot向量除了一位数字是1外,其余维度数字都是0,比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]),标签3将表示为([0,0,0,1,0,0,0,0,0,0])。
因此,MNIST数据集的标签是一个[60000,10]的数字矩阵。
28*28=784,每张图片有784个像素点,对应着784个神经元。最后输出10个神经元对应着10个数字。
2.Softmax
Softmax作用就是把神经网络的输出转化为概率值。
我们知道MNIST的结果是0-9,我们模型可能推测出一张图片的数字9的概率是80%,是数字8的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。
二、MNIST数据集分类
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD
2.加载数据及数据预处理
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000, 28, 28)
print("x_shape:\n",x_train.shape)
# (60000,) 还未进行one-hot编码 需要后面自己操作
print("y_shape:\n",y_train.shape)
# (60000, 28, 28) -> (60000,784) reshape()中参数填入-1的话可以自动计算出参数结果 除以255.0是为了归一化
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
# 创建模型 输入784个神经元,输出10个神经元
model = Sequential([
# 定义输出是10 输入是784,设置偏置为1,添加softmax激活函数
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax"),
])
# 定义优化器
sgd = SGD(lr=0.2)
# 定义优化器,loss_function,训练过程中计算准确率
model.compile(
optimizer=sgd,
loss="mse",
metrics=['accuracy']
)
# 训练模型
model.fit(x_train,y_train,batch_size=32,epochs=10)
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
最终运行结果:
注意
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax")
这里用到了softmax激活函数。- 这里我们使用的
fit
方法进行的模型训练,之前的线性回归和非线性回归的模型训练方式和这不同。
代码:
model.compile(
optimizer=sgd,
loss="mse",
metrics=['accuracy']
)
中添加metrics=['accuracy']
,可以在训练过程中计算准确率。
边栏推荐
- 国外众测之密码找回漏洞
- Marubeni official website applet configuration tutorial is coming (with detailed steps)
- Leetcode brush questions
- Application practice | the efficiency of the data warehouse system has been comprehensively improved! Data warehouse construction based on Apache Doris in Tongcheng digital Department
- Class head up rate detection based on face recognition
- 9.卷积神经网络介绍
- jemter分布式
- 1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
- 韦东山第二期课程内容概要
- Stock account opening is free of charge. Is it safe to open an account on your mobile phone
猜你喜欢
2022-07-07: the original array is a monotonic array with numbers greater than 0 and less than or equal to K. there may be equal numbers in it, and the overall trend is increasing. However, the number
基于微信小程序开发的我最在行的小游戏
Kubernetes Static Pod (静态Pod)
1293_FreeRTOS中xTaskResumeAll()接口的实现分析
51 communicates with the Bluetooth module, and 51 drives the Bluetooth app to light up
英雄联盟胜负预测--简易肯德基上校
新库上线 | CnOpenData中华老字号企业名录
Password recovery vulnerability of foreign public testing
基于卷积神经网络的恶意软件检测方法
语义分割模型库segmentation_models_pytorch的详细使用介绍
随机推荐
New library launched | cnopendata China Time-honored enterprise directory
5.过拟合,dropout,正则化
牛客基础语法必刷100题之基本类型
My best game based on wechat applet development
【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
Cve-2022-28346: Django SQL injection vulnerability
【笔记】常见组合滤波电路
Deep dive kotlin collaboration (the end of 23): sharedflow and stateflow
Is it safe to open an account on the official website of Huatai Securities?
【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出
Analysis of 8 classic C language pointer written test questions
Thinkphp内核工单系统源码商业开源版 多用户+多客服+短信+邮件通知
语义分割模型库segmentation_models_pytorch的详细使用介绍
letcode43:字符串相乘
Kubernetes Static Pod (静态Pod)
1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
5G NR 系统消息
Play sonar
The standby database has been delayed. Check that the MRP is wait_ for_ Log, apply after restarting MRP_ Log but wait again later_ for_ log
丸子官网小程序配置教程来了(附详细步骤)