MNIST Handwritten Digit Recognition: From Perceptron to Convolutional Neural Network
2022-08-04 05:30:00 【学习历险记】
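This section covers how to swap out a model's network structure and how to build a convolutional neural network (CNN).
The previous section recognized handwritten digits with a perceptron model that has ten output nodes, but after 100 epochs of training it reached an accuracy of only 0.8037. Adjusting max_epochs, the loss function, the gradient descent method, or the learning rate does little to raise that number. As the saying goes, if the underlying logic stays the same, surface-level tweaks rarely bring a large improvement. In that situation it is worth changing the model's underlying logic itself: the network structure.
This example uses a CNN.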
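1. Load and process the dataset
Reuse the load_data_all and process_dataset functions saved in the previous section to load and process the full handwritten digit dataset.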
import os
import sys
import moxing as mox

datasets_dir = '../datasets'
if not os.path.exists(datasets_dir):
    os.makedirs(datasets_dir)

# Download and unzip the dataset archive if it is not already present
if not os.path.exists(os.path.join(datasets_dir, 'MNIST_Data.zip')):
    mox.file.copy('obs://modelarts-labs-bj4-v2/course/hwc_edu/python_module_framework/datasets/mindspore_data/MNIST_Data.zip',
                  os.path.join(datasets_dir, 'MNIST_Data.zip'))
    os.system('cd %s; unzip MNIST_Data.zip' % (datasets_dir))

# Make the helper modules saved in the previous section importable
sys.path.insert(0, os.path.join(os.getcwd(), '../datasets/MNIST_Data'))
from load_data_all import load_data_all
from process_dataset import process_dataset

mnist_ds_train, mnist_ds_test, train_len, test_len = load_data_all(datasets_dir)  # load the dataset
mnist_ds_train = process_dataset(mnist_ds_train, batch_size=60000)  # process the training set
mnist_ds_test = process_dataset(mnist_ds_test, batch_size=10000)  # process the test set

Training set size: 60000, test set size: 10000
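A note on the batch sizes used here. Assuming process_dataset groups the data into batches of batch_size samples (the helper's code is not shown in this section, so this is an assumption), a batch size of 60000 puts the whole training set into a single batch, which means each epoch performs exactly one parameter update. The arithmetic can be sketched as follows; steps_per_epoch is a hypothetical helper name, not part of the original code.

```python
import math

def steps_per_epoch(num_samples, batch_size, drop_remainder=False):
    """Number of batches (and therefore parameter updates) in one epoch.

    Hypothetical helper for illustration only; it is not part of the
    original tutorial code or of MindSpore.
    """
    if drop_remainder:
        return num_samples // batch_size
    return math.ceil(num_samples / batch_size)

# Settings used above: the whole training set in one batch
full_batch_steps = steps_per_epoch(60000, 60000)

# A hypothetical mini-batch setting, for comparison only
mini_batch_steps = steps_per_epoch(60000, 32)

print(full_batch_steps)  # 1
print(mini_batch_steps)  # 1875
```

Under that assumption, the 100 epochs used later amount to 100 parameter updates in total, which is worth keeping in mind when reading the training log.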
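2. Build the CNN and the evaluation function
The evaluation function is reused unchanged from the previous section, while the network structure needs substantial changes. The code is below; see the comments for what each part does.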
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal

class Network(nn.Cell):
    """
    The network has only three layers: convolution layer 1, convolution layer 2
    and fully connected layer 1. ReLU and MaxPool2d have no parameters, so they
    are not counted as layers.
    """
    def __init__(self, num_of_weights):
        # num_of_weights is not used by this network
        super(Network, self).__init__()
        # Convolution 1: 1 input channel, 16 output channels, kernel size 5,
        # stride 1, no edge padding
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16,
                               kernel_size=5, pad_mode='valid',
                               stride=1, padding=0)
        # Convolution 2: 16 input channels, 32 output channels, kernel size 5,
        # stride 1, no edge padding
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32,
                               kernel_size=5, pad_mode='valid',
                               stride=1, padding=0)
        # Fully connected layer 1: input dimension 32*4*4, output dimension 10
        self.fc = nn.Dense(32 * 4 * 4, 10, weight_init=Normal(0.02))
        self.relu = nn.ReLU()  # activation layer; ReLU is the most common choice in CNNs
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)  # max pooling layer
        self.flatten = nn.Flatten()

    def construct(self, x):
        """
        Forward pass
        """
        # Convolution 1
        out = self.conv1(x)      # convolution
        out = self.relu(out)     # activation
        out = self.maxpool(out)  # pooling
        # Convolution 2
        out = self.conv2(out)    # convolution
        out = self.relu(out)     # activation
        out = self.maxpool(out)  # pooling
        # Fully connected 1
        # Before the fully connected layer, flatten the 32 feature maps of
        # size 4*4 into a one-dimensional vector
        out = self.flatten(out)
        out = self.fc(out)       # fully connected layer
        return out
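The 32 * 4 * 4 input size of the fully connected layer follows from how each layer shrinks the feature map. With no padding, a window of size k and stride s maps a side length n to floor((n - k) / s) + 1. The sketch below traces that through the two convolution and pooling stages, assuming a 28x28 input image (the standard MNIST size; the code above does not state it explicitly). The helper names conv_out and feature_map_sizes are illustrative only.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output side length of a convolution or pooling window on a square input."""
    return (size + 2 * padding - kernel) // stride + 1

def feature_map_sizes(input_size=28):
    """Trace the spatial size through conv1 -> maxpool -> conv2 -> maxpool."""
    sizes = [input_size]
    # (kernel, stride) for conv1, maxpool, conv2, maxpool as configured above
    for kernel, stride in [(5, 1), (2, 2), (5, 1), (2, 2)]:
        sizes.append(conv_out(sizes[-1], kernel, stride))
    return sizes

sizes = feature_map_sizes(28)
flat_features = 32 * sizes[-1] * sizes[-1]  # 32 channels after conv2

print(sizes)          # [28, 24, 12, 8, 4]
print(flat_features)  # 512
```

The final 4x4 maps over 32 channels give 512 features, which matches the first argument of nn.Dense in the network.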
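def evaluate(pred_y, true_y):
    # Predicted label = index of the largest output in each row
    pred_labels = ops.Argmax(output_type=mindspore.int32)(pred_y)
    # Count the predictions that match the true labels
    correct_num = (pred_labels == true_y).asnumpy().sum().item()
    return correct_num

3. Define the cross-entropy loss function and optimizer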
Reused from the previous section.
# Loss function
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# Create the network
network = Network(28*28)
lr = 0.01
momentum = 0.9
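To give a feel for what lr and momentum control, here is a minimal pure-Python sketch of the classical momentum update (velocity = momentum * velocity + gradient, then parameter = parameter - lr * velocity) applied to a one-dimensional toy function. It is assumed, not verified here, to correspond to the non-Nesterov form of a momentum optimizer; momentum_step is a hypothetical helper and the toy objective is not from the original article.

```python
def momentum_step(param, velocity, grad, lr=0.01, momentum=0.9):
    """One classical momentum update on a scalar parameter.

    Illustrative helper only; not the MindSpore implementation.
    """
    velocity = momentum * velocity + grad
    param = param - lr * velocity
    return param, velocity

# Toy objective f(p) = p^2 with gradient 2p, starting from p = 1.0
p, v = 1.0, 0.0
history = []
for _ in range(200):
    p, v = momentum_step(p, v, grad=2.0 * p)
    history.append(p)

print(history[0])        # after the first step the parameter has moved from 1.0 to about 0.98
print(abs(history[-1]))  # close to the minimum at 0 after 200 steps
```

The velocity accumulates past gradients, so consecutive steps in the same direction grow larger than plain gradient descent with the same learning rate would take.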
# Optimizer
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)

4. Implement the training function
Reused from the previous section.
from mindspore.nn import WithLossCell, TrainOneStepCell

def train(network, mnist_ds_train, max_epochs=50):
    # Relies on the globals net_loss, net_opt, mnist_ds_test, train_len and
    # test_len defined in the earlier steps
    net = WithLossCell(network, net_loss)  # attach the loss to the network
    net = TrainOneStepCell(net, net_opt)   # forward, backward and parameter update in one call
    network.set_train()
    for epoch in range(1, max_epochs + 1):
        train_correct_num = 0.0
        test_correct_num = 0.0
        for inputs_train in mnist_ds_train:
            output = net(*inputs_train)      # one training step; returns the loss
            train_x = inputs_train[0]
            train_y = inputs_train[1]
            pred_y_train = network(train_x)  # forward pass
            train_correct_num += evaluate(pred_y_train, train_y)
        train_acc = float(train_correct_num) / train_len
        for inputs_test in mnist_ds_test:
            test_x = inputs_test[0]
            test_y = inputs_test[1]
            pred_y_test = network(test_x)
            test_correct_num += evaluate(pred_y_test, test_y)
        test_acc = float(test_correct_num) / test_len
        # Report on the first epoch and then every 10 epochs
        if (epoch == 1) or (epoch % 10 == 0):
            print("epoch: {0}/{1}, train_losses: {2:.4f}, tain_acc: {3:.4f}, test_acc: {4:.4f}".format(
                epoch, max_epochs, output.asnumpy(), train_acc, test_acc), flush=True)

5. Configure the runtime
The hardware selected here is a GPU.
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")  # when GPU is selected, the MindSpore environment must also be switched to its GPU version

6. Start training
import time
from mindspore.nn import WithLossCell, TrainOneStepCell

max_epochs = 100
start_time = time.time()
print("*"*10 + "Start training" + "*"*10)
train(network, mnist_ds_train, max_epochs=max_epochs)
print("*"*10 + "Training finished" + "*"*10)
cost_time = round(time.time() - start_time, 1)
print("Total training time: %.1f s" % cost_time)

**********Start training**********
epoch: 1/100, train_losses: 2.3024, tain_acc: 0.1307, test_acc: 0.1312
epoch: 10/100, train_losses: 2.3004, tain_acc: 0.1986, test_acc: 0.1980
epoch: 20/100, train_losses: 2.2937, tain_acc: 0.3035, test_acc: 0.3098
epoch: 30/100, train_losses: 2.2683, tain_acc: 0.3669, test_acc: 0.3754
epoch: 40/100, train_losses: 2.1102, tain_acc: 0.4212, test_acc: 0.4290
epoch: 50/100, train_losses: 1.0519, tain_acc: 0.7415, test_acc: 0.7551
epoch: 60/100, train_losses: 1.3377, tain_acc: 0.7131, test_acc: 0.7190
epoch: 70/100, train_losses: 0.9068, tain_acc: 0.7817, test_acc: 0.7888
epoch: 80/100, train_losses: 0.4193, tain_acc: 0.8732, test_acc: 0.8843
epoch: 90/100, train_losses: 0.3339, tain_acc: 0.9000, test_acc: 0.9069
epoch: 100/100, train_losses: 0.2796, tain_acc: 0.9177, test_acc: 0.9219
**********Training finished**********
Total training time: 261.2 s
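The output above shows that, with a LeNet-style network trained for the same 100 epochs, test accuracy reaches about 92% (0.9219), a clear improvement over the perceptron's 0.8037.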
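One detail in the log is worth a quick check: the loss at epoch 1, 2.3024, is very close to ln(10), which is about 2.3026. That is the value softmax cross-entropy takes when a 10-class model assigns roughly equal probability to every class, as one would expect from a freshly initialized network. The pure-Python sketch below illustrates this; softmax_cross_entropy is a hand-written stand-in for illustration, not the MindSpore operator, and the logits are made-up values.

```python
import math

def softmax_cross_entropy(logits, label):
    """Cross-entropy of softmax(logits) against an integer class label."""
    m = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# All ten logits equal: every class gets probability 1/10
uniform_loss = softmax_cross_entropy([0.0] * 10, label=3)

# Made-up logits where the correct class clearly stands out
confident_logits = [0.0] * 10
confident_logits[3] = 5.0
confident_loss = softmax_cross_entropy(confident_logits, label=3)

print(round(uniform_loss, 4))         # 2.3026
print(confident_loss < uniform_loss)  # True
```

Read this way, the slow drop from 2.3024 over the first 40 epochs in the log corresponds to a model that is still close to guessing uniformly.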