Implementing MNIST Handwritten Digit Recognition
2022-06-24 19:50:00 【ㄣ知冷煖*】
Preface
Implementing MNIST handwritten digit recognition.
1. Code implementation
import tensorflow as tf
from tensorflow.keras.datasets import mnist
import matplotlib.pyplot as plt
from tensorflow.keras import models
from tensorflow.keras import layers
(train_images,train_labels), (test_images, test_labels) = mnist.load_data()
# train_images.shape: (60000, 28, 28), i.e. 60,000 images of 28x28 pixels each.
# Build the neural network
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
# The output layer has one unit per class; this is a 10-class problem.
network.add(layers.Dense(10, activation='softmax'))
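# compile: sets the loss function, the optimizer, and the metrics to monitor during training and testing
# metrics: a list of metrics; for classification it is usually metrics=['accuracy']; for regression with a mean squared error loss, use 'mse'
# Use 'categorical_crossentropy' as the loss for multi-class classification and 'binary_crossentropy' for binary classification
network.compile(optimizer='rmsprop',
                loss='categorical_crossentropy',
                metrics=['accuracy'])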
# Data preprocessing: reshape the images into the shape the network expects and scale pixel values to [0, 1]
train_images = train_images.reshape((60000, 28*28))
train_images = train_images.astype('float32')/255
test_images = test_images.reshape((10000, 28*28))
test_images = test_images.astype('float32')/255
from tensorflow.keras.utils import to_categorical
# to_categorical: converts a vector of integer class labels into a binary (0/1) matrix, i.e. a one-hot encoding.
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
# Train the model
network.fit(train_images, train_labels, epochs=20, batch_size=128)
# Evaluate on the test set
test_loss, test_acc = network.evaluate(test_images, test_labels)
print(test_loss, test_acc)
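The preprocessing steps in the listing can be checked without downloading the dataset. The following is a minimal sketch using NumPy only. The random fake_images and fake_labels arrays are stand-ins (an assumption for illustration, not MNIST data), and one_hot is a hypothetical helper that mimics what to_categorical does for integer labels.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in data with the same dtype and per-image shape as MNIST (not real MNIST).
fake_images = rng.integers(0, 256, size=(5, 28, 28), dtype=np.uint8)
fake_labels = np.array([3, 0, 9, 1, 3])

# Flatten each 28x28 image into a 784-element vector and scale pixels to [0, 1].
flat = fake_images.reshape((fake_images.shape[0], 28 * 28)).astype("float32") / 255


def one_hot(labels, num_classes=10):
    """Hypothetical helper: one-hot encode a vector of integer class labels."""
    encoded = np.zeros((len(labels), num_classes), dtype="float32")
    encoded[np.arange(len(labels)), labels] = 1.0
    return encoded


encoded_labels = one_hot(fake_labels)
print(flat.shape, flat.dtype)
print(encoded_labels[0])
```

Each row of encoded_labels contains a single 1 at the index of the class, which is the label format that the 'categorical_crossentropy' loss expects.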
2. Some points to note
2-1. Ways to build the network
The network can be built with the add method:
network = models.Sequential()
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
network.add(layers.Dense(10, activation='softmax'))
Or by passing a list of layers to the constructor:
network = models.Sequential([
    layers.Dense(512, activation='relu', input_shape=(28*28,)),
    layers.Dense(10, activation='softmax'),
])
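2-2. Specifying the model's input shape
Only the first layer needs to be told the shape of the input data through an argument. Later layers do not, because each one infers its input shape from the output of the layer before it.
Using the input_shape argument:
network.add(layers.Dense(512, activation='relu', input_shape=(28*28,)))
Alternatively, use the input_dim argument, which means the same thing here:
network.add(layers.Dense(512, activation='relu', input_dim=28*28))
Note: input_shape=(28*28,) means that each input sample is a rank-1 tensor (a vector) with 28*28 = 784 elements. input_shape must be a tuple, so it has to be written as (28*28,), with the trailing comma.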
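The shape inference described above can be illustrated in plain Python. This is a sketch of the bookkeeping only, not Keras code: dense_param_count is a hypothetical helper, and it assumes a fully connected layer with one weight per input/unit pair plus one bias per unit.

```python
def dense_param_count(input_dim, units):
    """Hypothetical helper: weights (input_dim x units) plus one bias per unit."""
    return input_dim * units + units


input_dim = 28 * 28       # declared once, on the first layer
layer_units = [512, 10]   # the two Dense layers from the listing

params = []
for units in layer_units:
    params.append(dense_param_count(input_dim, units))
    # The next layer's input size is this layer's output size,
    # which is why only the first layer needs an explicit input shape.
    input_dim = units

print(params, sum(params))  # [401920, 5130] 407050
```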
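2-3. Tensor operations inside a fully connected layer
Example:
keras.layers.Dense(512, activation='relu')
Explanation: the layer takes a 2D tensor as input and returns another 2D tensor. It computes the following function.
Formula: output = relu(dot(input, W) + b)
That is: a dot product between the input tensor and the weight tensor W (which has a fixed shape and is randomly initialized, then updated during training), followed by adding the bias vector b to the resulting 2D tensor, followed by the relu activation (max(x, 0)). Both the relu and the addition are element-wise operations. Here input has shape (batch_size, 784) and W has shape (784, 512), which is why input is the first argument of dot.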
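The computation of such a layer can be written out in NumPy. This is a minimal sketch, assuming the input is a batch of row vectors of shape (batch, 784) and W has shape (784, 512), so the product is taken as dot(x, W). The function dense_relu and the random initial values are illustrative and not the Keras implementation.

```python
import numpy as np

rng = np.random.default_rng(0)


def dense_relu(x, W, b):
    """Forward pass of a fully connected layer followed by relu."""
    z = np.dot(x, W) + b        # (batch, in) . (in, units) -> (batch, units); b is broadcast over rows
    return np.maximum(z, 0.0)   # element-wise relu: max(z, 0)


x = rng.random((4, 784)).astype("float32")                     # a batch of 4 flattened images
W = rng.normal(scale=0.05, size=(784, 512)).astype("float32")  # randomly initialized weights
b = np.zeros(512, dtype="float32")                             # bias vector

out = dense_relu(x, W, b)
print(out.shape)  # (4, 512)
```

During training, the optimizer would adjust W and b; the sketch only shows a single forward pass.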
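2-4. Notes on the dot product
keras.layers.Dense(512, activation='relu')
Note: the dot product of two vectors is a scalar, and it is only defined for vectors with the same number of elements: multiply them element by element, then sum the products.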
import numpy as np
np.dot([1, 2],[3,4])
# Output
# 11
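In general, for two matrices x and y, the dot product is defined if and only if x.shape[1] == y.shape[0]. The result is a matrix of shape (x.shape[0], y.shape[1]) whose entry (i, j) is the dot product of row i of x and column j of y.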
np.dot([[1, 2],[1,2]], [[3, 4],[3,4]])
# Output
# array([[ 9, 12],
# [ 9, 12]])
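The shape rule for matrices can be checked directly. The sketch below uses small matrices of ones (chosen only for illustration) to show a valid product and a product that fails because the inner dimensions differ.

```python
import numpy as np

x = np.ones((2, 3))
y = np.ones((3, 4))

# x.shape[1] == y.shape[0] == 3, so the product is defined.
z = np.dot(x, y)
print(z.shape)  # (2, 4)

# Swapping the operands gives (3, 4) . (2, 3): 4 != 2, so NumPy raises ValueError.
try:
    np.dot(y, x)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
print(mismatch_raised)  # True
```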
Summary
For some things, having made the effort is enough, even when the result is painful to look at...