当前位置:网站首页>实现mnist手写数字识别
实现mnist手写数字识别
2022-08-01 16:58:00 【重邮研究森】
活动地址:CSDN21天学习挑战赛
创作计划
**
1,机缘
写csdn也快一年了,最初的原因是因为我师兄,他写了很大有用的文章,当时我觉得非常帅!!并且通过这个方式也可以帮助我解决不记笔记的问题。通过写文章不仅可以记录笔记,也可以通过记录的方式加强我的学习体验。
2,收获
收获了很多的有用的文章,也收获了几百个小伙伴的关注!!!
3,日常
我现在刷题的话,也会把题目进行记录,所以记录已经成为了我的日常,我的学习内容都会写到这里
4,憧憬
希望毕业能够留在成都,年薪30
**
目录
文中一些基础概念的知识在我写的这个文章中可以查阅
(2条消息) tensorflow零基础入门学习_重邮研究森的博客-CSDN博客https://blog.csdn.net/m0_60524373/article/details/124143223
1.跑通代码
我这个人对于任何代码,我都会先去跑通之和才会去观看内容,哈哈哈,所以第一步我们先不管37=21,直接把博主的代码复制黏贴一份运行结果。(PS:做了一些修改,因为原文是在jupyter,而我在pycharm)
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
# train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
# plt.figure(figsize=(20,10))
# for i in range(20):
# plt.subplot(5,10,i+1)
# plt.xticks([])
# plt.yticks([])
# plt.grid(False)
# plt.imshow(train_images[i], cmap=plt.cm.binary)
# plt.xlabel(train_labels[i])
# plt.show()
#调整数据到我们需要的格式
train_images = train_images.reshape((60000, 28, 28, 1))#60k像素,尺寸28x28,通道数=1
test_images = test_images.reshape((10000, 28, 28, 1))
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
"""
输出:((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))
"""
#
model = models.Sequential([
layers.Conv2D()
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), # 卷积层1,卷积核个数32,卷积核尺寸3*3
layers.MaxPooling2D((2, 2)), # 池化层1,2*2采样 (28-2)/stride+1=26
layers.Conv2D(64, (3, 3), activation='relu'), # 卷积层2,卷积核3*3
layers.MaxPooling2D((2, 2)), # 池化层2,2*2采样
layers.Flatten(), # Flatten层,连接卷积层与全连接层
layers.Dense(64, activation='relu'), # 全连接层得到一维向量64,特征进一步提取
layers.Dense(10) # 输出层一维向量10,输出预期结果
])
# model = models.Sequential([
# #图像数据需要拉直
# tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
# tf.keras.layers.Dense(128, activation='relu'),
# tf.keras.layers.Dense(10, activation='softmax')
# ])
# 打印网络结构
model.summary()
#优化器:adam loss函数没经过概率分布 accuracy" : y_ 和 y 都是数值,如y_ = [1] y = [1] #y_为真实值,y为预测值
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
history = model.fit(train_images, train_labels, epochs=10,
validation_data=(test_images, test_labels))
plt.imshow(test_images[1])
plt.show()
pre = model.predict(test_images) # 对所有测试图片进行预测
print( pre[1]) # 输出第一张图片的预测结果
点击pycharm即可运行出最后的预测结果!
2.代码分析
神经网络的整个过程我分为如下六部分,而我们也会对这六部分进行逐部分分析。那么这6部分分别是:
六步法:
1->import
2->train test(指定训练集的输入特征和标签)
3->class MyModel(model) model=Mymodel(搭建网络结构,逐层描述网络)
4->model.compile(选择哪种优化器,损失函数)
5->model.fit(执行训练过程,输入训练集和测试集的特征+标签,batch,迭代次数)
6->验证
2.1
导入:这里很容易理解,也就是导入本次实验内容所需要的各种库。在本案例中主要包括以下部分:
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
主要是tensorflow以及绘制的库。
对于这里的话我们可以直接复制黏贴,当需要一些其他函数时,只需要添加对应的库文件即可。
2.2
设置训练集和测试集:对于神经网络的训练包括了两种数据集合,一个是训练集,一个是测试集。其中训练集数据较多,测试集较少,因为训练一个模型数据越多相对的模型更准确。
本文中的数据集利用了网络的mnist数据集,该数据集是一个手写0-9的数据集合
对于该类数据集的使用,我们利用下面函数即可使用该数据集
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
注意事项:由于本实验中的都是图像数据集,为了使网络训练结果更好,我们需要对图像数据进行标准化。像素点是255个,所以对于数据整除255即可。
train_images, test_images = train_images / 255.0, test_images / 255.0
在进行了标准化之和,我们的图像数据还是不能直接传入,对于网络模型的输入,我们需要让输入数据和网络模型的“入口”保持一致。因此我们还需要把数据进行尺寸修改,这里的修改大小倒是不明确要求。
train_images = train_images.reshape((60000, 28, 28, 1))#60k像素,尺寸28x28,通道数=1
test_images = test_images.reshape((10000, 28, 28, 1))
2.3
网络模型搭建:这里也是神经网络的重点了!废话不多说,直接开始!
本文的神经网络的结构图如下:
在搭建模型的时候,我们将按照这个图片进行模型的搭建。
卷积层1:32通道,3x3尺寸,步长1的卷积核
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))
注意事项:这里是网络模型的第一层,因此要加上输入
池化层1:该池化层为2x2
layers.MaxPooling2D((2, 2))
卷积层2:64通道,3x3尺寸,步长1的卷积核
layers.Conv2D(64, (3, 3), activation='relu')
池化层2:该池化层为2x2
layers.MaxPooling2D((2, 2))
flatten:为拉伸层,也就是把之前的多维数据转为一维,因为其后面层为全连接层,而全连接层为一维数据
layers.Flatten()
全连接层1
layers.Dense(64, activation='relu')
全连接层2
layers.Dense(10)
重点:
现在我们来分析一下图片中经过每层后数据的维度怎么来的
经过卷积层1之后,原数据28x28变为26x26是因为一个公式: (28-2)/stride+1=26
经过池化层1之后,原数据26x26变为13x13是因为池化池的卷积核为2,所以13=26/2
经过卷积层2之后,原数据13变为11:如上,32变为64是因为此时卷积核通道数为64
经过flatten层之后,数据数量=5*5*64=1600
而后续全连接层的输出是根据全连接层代码设置。需要注意的是因为数据集是10种类型,因此最后为10
到此,我们便把网络模型设置的原因以及网络模型的输出结果进行了对应,我们可以看到网络模型的输出和我们分析的一致。
到此,网络模型我们变分析完了。
2.4
该部分也同样重要,主要完成模型训练过程中的优化器,损失函数,准确率的设置。
我们结合本文来看。
model.compile(optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
其中:对于这三个内容的含义可以参考我的文章开头的另外一篇基础博文进行了详细的介绍
2.5
该部分就是执行训练了,那么执行训练肯定需要设置训练集数据及其标签,测试集数据及其标签,训练的epoch
history = model.fit(train_images, train_labels, epochs=10,
validation_data=(test_images, test_labels))
2.6
当训练执行完毕,我们就可以拿一个测试集合中或者其他满足格式的数据进行验证了,这里为了方便,我就用测试集合进行验证。
pre = model.predict(test_images) # 对所有测试图片进行预测
print( pre[1]) # 输出第一张图片的预测结果
边栏推荐
- 【二叉树】奇偶树
- 直播app开发,是优化直播体验不得不关注的两大指标
- Detailed explanation of the working principle of crystal oscillator
- DataTable Helper Class for C#
- 90后的焦虑,被菜市场治好了
- 访问域名直接访问wordpress
- OneFlow源码解析:Op、Kernel与解释器
- 金仓数据库 KingbaseES V8.3 至 V8.6 迁移最佳实践(4. V8.3 到 V8.6 数据库移植实战)
- Vulnhub靶机:HARRYPOTTER_ NAGINI
- 06 redis cluster structures
猜你喜欢
随机推荐
【Unity,C#】哨兵点位循迹模板代码
关于2022年深圳市福田区支持高端服务业发展项目的申报通知
【硬核拆解】50块2个的2022年夏季款智能节电器到底能不能省电?
表达式;运算符,算子;取余计算;运算符优先顺序
金仓数据库KingbaseES安全指南--6.4. RADIUS身份验证
统信软件、龙芯中科等四家企业共同发布《数字办公安全创新方案》
DevExpress的GridControl帮助类
JumpServer堡垒机部署
MLX90640 红外热成像仪测温模块开发笔记(完整版)
个人日记
MySQL's maximum recommended number of rows is 2000w, is it reliable?
LeetCode第 303 场周赛
11 一发布就发布一系列系列
完全背包问题求组合数和排列数
【R语言】线性混合模型进行重复测量设计分析
04 flink cluster construction
二分练习题
MySQL加锁案例分析
Bugku-Misc-贝斯手
04 flink 集群搭建