当前位置:网站首页>MNIST Handwritten Digit Recognition - Lenet-5's First Commercial Grade Convolutional Neural Network
MNIST Handwritten Digit Recognition - Lenet-5's First Commercial Grade Convolutional Neural Network
2022-08-04 06:19:00 【Learning Adventures】
MNISTThe reason why datasets are datasets for getting started with deep learning,是因为LeNet-5网络的诞生,The handwritten digit recognition effect of the network can reach the commercial level,It is the first deep neural network that is truly commercially available,It is widely used in the recognition of handwritten checks.
Convolutional neural networks are mostly used in image classification tasks,It uses a layered structure to extract features from images,It consists of a series of network layers stacked,比如卷积层、池化层、激活层等等.
本案例将使用Lenet-5来实现手写数字识别.
1. 加载并处理数据集
Because the convolutional neural network training requires more memory,The entire number cannot be identified6Thousands of samples are loaded for training at one time,因此需要分批加载训练集进行训练.
import os
import sys
import moxing as mox
datasets_dir = '../datasets'
if not os.path.exists(datasets_dir):
os.makedirs(datasets_dir)
if not os.path.exists(os.path.join(datasets_dir, 'MNIST_Data.zip')):
mox.file.copy('obs://modelarts-labs-bj4-v2/course/hwc_edu/python_module_framework/datasets/mindspore_data/MNIST_Data.zip',
os.path.join(datasets_dir, 'MNIST_Data.zip'))
os.system('cd %s; unzip MNIST_Data.zip' % (datasets_dir))
sys.path.insert(0, os.path.join(os.getcwd(), '../datasets/MNIST_Data'))
from load_data_all import load_data_all
from process_dataset import process_dataset
mnist_ds_train, mnist_ds_test, train_len, test_len = load_data_all(datasets_dir) # 加载数据集
mnist_ds_train = process_dataset(mnist_ds_train, batch_size= 32, resize= 32) # 处理训练集,分批加载
mnist_ds_test = process_dataset(mnist_ds_test, batch_size= 32, resize= 32) # 处理测试集,分批加载训练集规模:60000,测试集规模:10000
2. 构建LeNet-5Networks and Merit Functions
LeNet-5有5层网络,分别是卷积层1、卷积层2、全连接层1、全连接层2、全连接层3,网络结构如下图所示:

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal
class Network(nn.Cell):
"""Lenet network structure."""
# define the operator required
def __init__(self, num_class=10, num_channel=1):
super(Network, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() # 输入到全连接层之前需要将16个5*5大小的特性矩阵拉成一个一维向量
# use the preceding operators to construct networks
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def evaluate(pred_y, true_y):
pred_labels = ops.Argmax(output_type=mindspore.int32)(pred_y)
correct_num = (pred_labels == true_y).asnumpy().sum().item()
return correct_num3. 定义交叉熵损失函数和优化器
# 损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 创建网络
network = Network(28*28)
lr = 0.01
momentum = 0.9
# 优化器
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)4. 实现训练函数
def train(network, mnist_ds_train, max_epochs= 50):
net = WithLossCell(network, net_loss)
net = TrainOneStepCell(net, net_opt)
network.set_train()
for epoch in range(1, max_epochs + 1):
train_correct_num = 0.0
test_correct_num = 0.0
for inputs_train in mnist_ds_train:
output = net(*inputs_train)
train_x = inputs_train[0]
train_y = inputs_train[1]
pred_y_train = network.construct(train_x) # 前向传播
train_correct_num += evaluate(pred_y_train, train_y)
train_acc = float(train_correct_num) / train_len
for inputs_test in mnist_ds_test:
test_x = inputs_test[0]
test_y = inputs_test[1]
pred_y_test = network.construct(test_x)
test_correct_num += evaluate(pred_y_test, test_y)
test_acc = float(test_correct_num) / test_len
if (epoch == 1) or (epoch % 10 == 0):
print("epoch: {0}/{1}, train_losses: {2:.4f}, tain_acc: {3:.4f}, test_acc: {4:.4f}".format(epoch, max_epochs, output.asnumpy(), train_acc, test_acc, cflush=True))5. 配置运行信息
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") # device_target 可选 CPU/GPU, 当选择GPU时mindspore规格也需要切换到GPU6. 开始训练
训练50个epoch,耗时约11分钟
import time
from mindspore.nn import WithLossCell, TrainOneStepCell
max_epochs = 50
start_time = time.time()
print("*"*10 + "开始训练" + "*"*10)
train(network, mnist_ds_train, max_epochs= max_epochs)
print("*"*10 + "训练完成" + "*"*10)
cost_time = round(time.time() - start_time, 1)
print("训练总耗时: %.1f s" % cost_time)**********开始训练********** epoch: 1/50, train_losses: 2.3162, tain_acc: 0.1095, test_acc: 0.1027 epoch: 10/50, train_losses: 0.0008, tain_acc: 0.9942, test_acc: 0.9846 epoch: 20/50, train_losses: 0.0001, tain_acc: 0.9984, test_acc: 0.9832 epoch: 30/50, train_losses: 0.1701, tain_acc: 0.9996, test_acc: 0.9839 epoch: 40/50, train_losses: 0.0000, tain_acc: 1.0000, test_acc: 0.9861 epoch: 50/50, train_losses: 0.0000, tain_acc: 1.0000, test_acc: 0.9864 **********训练完成********** 训练总耗时: 686.4 s
可以看到,使用LeNet-5网络训练11分钟,50个epoch后,测试准确率达到98.6%
边栏推荐
- 计算某像素点法线
- 安装dlib踩坑记录,报错:WARNING: pip is configured with locations that require TLS/SSL
- 如何成长为高级工程师?
- 多层LSTM
- 关于DG(域泛化)领域的PCL方法的代码实例
- Android connects to mysql database using okhttp
- 0, deep learning 21 days learning challenge 】 【 set up learning environment
- Pytest常用插件
- The pipeline mechanism in sklearn
- 【CV-Learning】Object Detection & Instance Segmentation
猜你喜欢

TensorFlow2 study notes: 6. Overfitting and underfitting, and their mitigation solutions

YOLOV5 V6.1 详细训练方法

【CV-Learning】Image Classification

Amazon Cloud Technology Build On 2022 - AIot Season 2 IoT Special Experiment Experience

数据库的简述与常用操作指南

Halcon缺陷检测

度量学习(Metric learning、损失函数、triplet、三元组损失、fastreid)

在AWS-EC2中安装Minikube集群

【CV-Learning】语义分割

腾讯、网易纷纷出手,火到出圈的元宇宙到底是个啥?
随机推荐
【CV-Learning】Image Classification
亚马逊云科技Build On-Amazon Neptune基于知识图谱的推荐模型构建心得
光条提取中的连通域筛除
【CV-Learning】卷积神经网络预备知识
度量学习(Metric learning、损失函数、triplet、三元组损失、fastreid)
WARNING: sql version 9.2, server version 11.0. Some psql features might not work.
Copy Siege Lion 5-minute online experience MindIR format model generation
Deep Adversarial Decomposition: A Unified Framework for Separating Superimposed Images
MOOSE平台官方第二个例子分析——关于创建Kernel,求解对流扩散方程
关于DG(域泛化)领域的PCL方法的代码实例
Copy攻城狮信手”粘“来 AI 对对联
图像线性融合
Briefly say Q-Q map; stats.probplot (QQ map)
AIDL communication between two APPs
Usage of Thread, Handler and IntentService
浅谈游戏音效测试点
MOOSE平台使用入门攻略——如何运行官方教程的例子
"A minute" Copy siege lion log 】 【 run MindSpore LeNet model
MNIST手写数字识别 —— 图像分析法实现二分类
0, deep learning 21 days learning challenge 】 【 set up learning environment