Deep Learning 100 Examples - Convolutional Neural Network (CNN) for mnist handwritten digit recognition
2022-08-02 10:40:00 【Ding Jiaxiong】
Event: CSDN 21-Day Learning Challenge
Teacher K uses the TensorFlow framework, so I created a new virtual environment for it.
Python: 3.7
Editor: Jupyter Notebook
Deep learning framework: TensorFlow 2.7.0


1. Preliminary preparation
1.1 Set up the GPU
In the end I chose to rent compute on AutoDL rather than run this on my own computer.
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")  # list the physical devices of the given type (e.g. GPU or CPU) on this host
if gpus:
    gpu0 = gpus[0]  # if there are several GPUs, use only GPU 0
    tf.config.experimental.set_memory_growth(gpu0, True)  # allocate GPU memory on demand
    tf.config.set_visible_devices([gpu0], "GPU")  # restrict the visible devices to GPU 0

1.2 Import the data
import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()
The MNIST handwritten digit dataset comes from the National Institute of Standards and Technology (NIST) and is one of the best-known public datasets. The digit images were handwritten by 250 people from different occupations.
The dataset contains 70,000 images: 60,000 for training and 10,000 for testing. Every image is 28 x 28 pixels.

Flattening a 28 x 28 image into a vector gives a vector of length 28 x 28 = 784.
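The flattening step can be illustrated with a small NumPy sketch. The array here is a stand-in built with `np.arange`, not real MNIST data, and the variable names are chosen only for this example.

```python
import numpy as np

# Stand-in for one 28 x 28 grayscale image (values 0..783, not real pixels).
image = np.arange(28 * 28).reshape(28, 28)

# reshape(-1) lays the rows out one after another (row-major order).
vector = image.reshape(-1)

# Pixel (row, col) therefore lands at index row * 28 + col.
row, col = 3, 5
same_pixel = vector[row * 28 + col] == image[row, col]

print(vector.shape, same_pixel)
```

Note that the CNN built later in this post keeps the 2-D layout and only flattens after the convolutional layers.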

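1.3 Normalize the data
Purpose:
- Bring features with different scales to the same order of magnitude, reducing the influence of high-variance features and making the model more accurate.
- Speed up the convergence of the learning algorithm.
Normalization: maps values into the range [0, 1]; the scaling depends only on the maximum and minimum values.
# Scale the pixel values into the range 0 to 1.
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape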
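Dividing by 255 is a special case of min-max normalization with a minimum of 0 and a maximum of 255. A minimal sketch, assuming 8-bit pixel values; the helper `min_max_normalize` is a name made up for this example:

```python
import numpy as np

def min_max_normalize(x, lo, hi):
    """Map values from the range [lo, hi] to [0, 1] via (x - lo) / (hi - lo)."""
    x = np.asarray(x, dtype=np.float64)
    return (x - lo) / (hi - lo)

# Three sample 8-bit pixel values: black, dark gray, white.
pixels = np.array([0, 51, 255], dtype=np.uint8)
scaled = min_max_normalize(pixels, 0, 255)

print(scaled)
```

With `lo=0` the formula reduces to `x / 255.0`, which is what the code above applies to the whole dataset.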


1.4 Visualize the images
plt.figure(figsize=(20,10))  # set the figure size
for i in range(20):
    plt.subplot(5,10,i+1)  # the figure is a 5 x 10 grid, numbered 1, 2, 3... from left to right, top to bottom
    plt.xticks([])  # hide the x-axis ticks
    plt.yticks([])  # hide the y-axis ticks
    plt.grid(False)  # turn off the grid lines
    plt.imshow(train_images[i], cmap=plt.cm.binary)  # the binary colormap draws the image in black and white; this shows images 0-19
    plt.xlabel(train_labels[i])  # use the corresponding label value as the x-axis label
plt.show()

1.5 Adjust the image format
# Reshape the data into the format we need
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
→ Reshape the data to add a channel dimension
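The reshape only appends a channel axis of size 1; no pixel value changes. A small sketch on a stand-in batch of zeros (not real MNIST data), which also shows that `-1` can replace the hard-coded sample count:

```python
import numpy as np

# Stand-in for a batch of 5 grayscale images.
batch = np.zeros((5, 28, 28), dtype=np.float32)

# Explicit sample count, as in the code above.
with_channel = batch.reshape((5, 28, 28, 1))

# -1 lets NumPy infer the sample count from the array size.
inferred = batch.reshape((-1, 28, 28, 1))

print(with_channel.shape, inferred.shape)
```

Using `-1` is a small convenience: the same line then works for the training set and the test set.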

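2. Build the CNN model
2.1 Build the network
model = models.Sequential(
    [
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=[28, 28, 1]),  # convolutional layer 1, 3 x 3 kernels
        layers.MaxPooling2D((2, 2)),  # pooling layer 1, 2 x 2 downsampling
        layers.Conv2D(64, (3, 3), activation='relu'),  # convolutional layer 2, 3 x 3 kernels
        layers.MaxPooling2D((2, 2)),  # pooling layer 2, 2 x 2 downsampling
        layers.Flatten(),  # Flatten layer, connects the convolutional layers to the fully connected layers
        layers.Dense(64, activation='relu'),  # fully connected layer, further feature extraction
        layers.Dense(10)  # output layer, one raw score (logit) per digit
    ]
)
model.summary()
Network structure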

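2.2 Model structure notes:

Role of each layer:
- Input layer: feeds the data into the network
- Convolutional layer: extracts features from the image using convolution kernels
- Pooling layer: downsamples, representing image features at a higher level of abstraction
- Flatten layer: flattens a multi-dimensional input into one dimension; commonly used between the convolutional layers and the fully connected layers
- Fully connected layer: extracts further features from the flattened vector
- Output layer: outputs the result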
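The shapes and parameter counts that `model.summary()` reports for this network can be worked out by hand. The sketch below assumes the Keras defaults used in the model above (stride-1 convolutions without padding, non-overlapping 2 x 2 pooling); the helper names are made up for this example.

```python
def conv_out(size, kernel):
    """Output side length of a stride-1 convolution with no padding."""
    return size - kernel + 1

def pool_out(size, pool):
    """Output side length of non-overlapping max pooling."""
    return size // pool

def conv_params(kernel, in_ch, out_ch):
    """Kernel weights plus one bias per output channel."""
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(n_in, n_out):
    """Weight matrix plus one bias per output unit."""
    return n_in * n_out + n_out

side = 28
side = conv_out(side, 3)   # 26 after convolutional layer 1
side = pool_out(side, 2)   # 13 after pooling layer 1
side = conv_out(side, 3)   # 11 after convolutional layer 2
side = pool_out(side, 2)   # 5 after pooling layer 2
flat = side * side * 64    # 1600 values enter the first Dense layer

params = [
    conv_params(3, 1, 32),   # 320
    conv_params(3, 32, 64),  # 18496
    dense_params(flat, 64),  # 102464
    dense_params(64, 10),    # 650
]
total = sum(params)          # 121930
print(flat, params, total)
```

Most of the parameters sit in the first fully connected layer, not in the convolutional layers.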
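3. Compile the model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
compile() configures training: the optimizer (optimizer), the loss function (loss), and the evaluation metric (metrics, here accuracy).
SparseCategoricalCrossentropy → cross-entropy loss for integer labels. With from_logits=True, the loss treats the model output as raw logits and applies softmax internally to turn the predictions y into probabilities, which is why the output layer Dense(10) has no activation.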
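To make the from_logits behaviour concrete, here is a plain-Python sketch of what the loss computes for a single sample: softmax over the logits, then the negative log of the probability assigned to the true label. It illustrates the formula and is not the TensorFlow implementation; the logits are made-up numbers.

```python
import math

def softmax(logits):
    """Convert raw scores to probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sparse_cross_entropy(logits, label):
    """Negative log-probability of the integer class `label`."""
    return -math.log(softmax(logits)[label])

logits = [2.0, 1.0, 0.1]   # made-up raw scores for 3 classes
probs = softmax(logits)
loss = sparse_cross_entropy(logits, 0)

# An untrained 10-class model with equal logits has loss log(10).
uniform_loss = sparse_cross_entropy([0.0] * 10, 3)
print(probs, loss, uniform_loss)
```

"Sparse" refers to the label format: the true class is given as an integer index such as 7, not as a one-hot vector.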
4. Train the model
history = model.fit(train_images, train_labels, epochs=10,
                    validation_data=(test_images, test_labels))
epochs: the number of passes over the training set; here all samples are trained on 10 times
validation_data: specifies the validation data.

Because accuracy was set as a metric when compiling the model, the accuracy is printed here during training
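The step count shown in the progress bar of each epoch follows from the dataset size and the batch size. The call above does not pass batch_size, so this sketch assumes the Keras default of 32; `steps_per_epoch` is a helper written for this example, not the `fit` argument of the same name.

```python
import math

def steps_per_epoch(num_samples, batch_size=32):
    """Number of batches needed to see every sample once."""
    return math.ceil(num_samples / batch_size)

train_steps = steps_per_epoch(60000)   # batches per epoch on the training set
total_updates = train_steps * 10       # weight updates over 10 epochs
print(train_steps, total_updates)
```

So one epoch is 1875 weight updates, and the 10 epochs amount to 18750 updates in total.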
5. Prediction
Take a look at image 0 of the test set

OK, it's a 7
Use the trained model to make a prediction:
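The code for this step did not survive in this copy of the post. As a hedged sketch of the usual approach: `model.predict(test_images)` returns one row of 10 logits per image, and the predicted digit is the index of the largest value. The logits below are hypothetical stand-ins, not real model output, and `predict_digits` is a helper named for this example.

```python
import numpy as np

def predict_digits(logits):
    """Return the index of the largest logit in each row."""
    return np.argmax(np.asarray(logits), axis=1)

# Hypothetical logits for two images; with the trained model this array
# would come from model.predict(test_images).
fake_logits = np.array([
    [-3.1, -2.0, 0.4, 1.2, -4.0, -1.5, -6.0, 9.8, -0.7, 0.3],
    [0.2, 1.1, 8.5, 0.9, -2.0, -1.0, 0.0, -3.0, 0.5, -1.2],
])
digits = predict_digits(fake_logits)
print(digits)
```

Because the model outputs logits, softmax is not needed to pick the class: it preserves the ordering, so the argmax is the same.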
