当前位置:网站首页>MNIST Handwritten Digit Recognition - Lenet-5's First Commercial Grade Convolutional Neural Network
MNIST Handwritten Digit Recognition - Lenet-5's First Commercial Grade Convolutional Neural Network
2022-08-04 06:19:00 【Learning Adventures】
MNISTThe reason why datasets are datasets for getting started with deep learning,是因为LeNet-5网络的诞生,The handwritten digit recognition effect of the network can reach the commercial level,It is the first deep neural network that is truly commercially available,It is widely used in the recognition of handwritten checks.
Convolutional neural networks are mostly used in image classification tasks,It uses a layered structure to extract features from images,It consists of a series of network layers stacked,比如卷积层、池化层、激活层等等.
本案例将使用Lenet-5来实现手写数字识别.
1. 加载并处理数据集
Because the convolutional neural network training requires more memory,The entire number cannot be identified6Thousands of samples are loaded for training at one time,因此需要分批加载训练集进行训练.
import os
import sys
import moxing as mox
datasets_dir = '../datasets'
if not os.path.exists(datasets_dir):
os.makedirs(datasets_dir)
if not os.path.exists(os.path.join(datasets_dir, 'MNIST_Data.zip')):
mox.file.copy('obs://modelarts-labs-bj4-v2/course/hwc_edu/python_module_framework/datasets/mindspore_data/MNIST_Data.zip',
os.path.join(datasets_dir, 'MNIST_Data.zip'))
os.system('cd %s; unzip MNIST_Data.zip' % (datasets_dir))
sys.path.insert(0, os.path.join(os.getcwd(), '../datasets/MNIST_Data'))
from load_data_all import load_data_all
from process_dataset import process_dataset
mnist_ds_train, mnist_ds_test, train_len, test_len = load_data_all(datasets_dir) # 加载数据集
mnist_ds_train = process_dataset(mnist_ds_train, batch_size= 32, resize= 32) # 处理训练集,分批加载
mnist_ds_test = process_dataset(mnist_ds_test, batch_size= 32, resize= 32) # 处理测试集,分批加载
训练集规模:60000,测试集规模:10000
2. 构建LeNet-5Networks and Merit Functions
LeNet-5有5层网络,分别是卷积层1、卷积层2、全连接层1、全连接层2、全连接层3,网络结构如下图所示:
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import Normal
class Network(nn.Cell):
"""Lenet network structure."""
# define the operator required
def __init__(self, num_class=10, num_channel=1):
super(Network, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten() # 输入到全连接层之前需要将16个5*5大小的特性矩阵拉成一个一维向量
# use the preceding operators to construct networks
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
def evaluate(pred_y, true_y):
pred_labels = ops.Argmax(output_type=mindspore.int32)(pred_y)
correct_num = (pred_labels == true_y).asnumpy().sum().item()
return correct_num
3. 定义交叉熵损失函数和优化器
# 损失函数
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
# 创建网络
network = Network(28*28)
lr = 0.01
momentum = 0.9
# 优化器
net_opt = nn.Momentum(network.trainable_params(), lr, momentum)
4. 实现训练函数
def train(network, mnist_ds_train, max_epochs= 50):
net = WithLossCell(network, net_loss)
net = TrainOneStepCell(net, net_opt)
network.set_train()
for epoch in range(1, max_epochs + 1):
train_correct_num = 0.0
test_correct_num = 0.0
for inputs_train in mnist_ds_train:
output = net(*inputs_train)
train_x = inputs_train[0]
train_y = inputs_train[1]
pred_y_train = network.construct(train_x) # 前向传播
train_correct_num += evaluate(pred_y_train, train_y)
train_acc = float(train_correct_num) / train_len
for inputs_test in mnist_ds_test:
test_x = inputs_test[0]
test_y = inputs_test[1]
pred_y_test = network.construct(test_x)
test_correct_num += evaluate(pred_y_test, test_y)
test_acc = float(test_correct_num) / test_len
if (epoch == 1) or (epoch % 10 == 0):
print("epoch: {0}/{1}, train_losses: {2:.4f}, tain_acc: {3:.4f}, test_acc: {4:.4f}".format(epoch, max_epochs, output.asnumpy(), train_acc, test_acc, cflush=True))
5. 配置运行信息
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU") # device_target 可选 CPU/GPU, 当选择GPU时mindspore规格也需要切换到GPU
6. 开始训练
训练50个epoch,耗时约11分钟
import time
from mindspore.nn import WithLossCell, TrainOneStepCell
max_epochs = 50
start_time = time.time()
print("*"*10 + "开始训练" + "*"*10)
train(network, mnist_ds_train, max_epochs= max_epochs)
print("*"*10 + "训练完成" + "*"*10)
cost_time = round(time.time() - start_time, 1)
print("训练总耗时: %.1f s" % cost_time)
**********开始训练********** epoch: 1/50, train_losses: 2.3162, tain_acc: 0.1095, test_acc: 0.1027 epoch: 10/50, train_losses: 0.0008, tain_acc: 0.9942, test_acc: 0.9846 epoch: 20/50, train_losses: 0.0001, tain_acc: 0.9984, test_acc: 0.9832 epoch: 30/50, train_losses: 0.1701, tain_acc: 0.9996, test_acc: 0.9839 epoch: 40/50, train_losses: 0.0000, tain_acc: 1.0000, test_acc: 0.9861 epoch: 50/50, train_losses: 0.0000, tain_acc: 1.0000, test_acc: 0.9864 **********训练完成********** 训练总耗时: 686.4 s
可以看到,使用LeNet-5网络训练11分钟,50个epoch后,测试准确率达到98.6%
边栏推荐
- 【CV-Learning】卷积神经网络预备知识
- MNIST手写数字识别 —— 图像分析法实现二分类
- tensorRT5.15 使用中的注意点
- 如何成长为高级工程师?
- Th in thymeleaf: href use notes
- Qt日常学习
- Amazon Cloud Technology Build On 2022 - AIot Season 2 IoT Special Experiment Experience
- TensorFlow2 study notes: 6. Overfitting and underfitting, and their mitigation solutions
- [CV-Learning] Linear Classifier (SVM Basics)
- Copy攻城狮5分钟在线体验 MindIR 格式模型生成
猜你喜欢
MNIST手写数字识别 —— ResNet-经典卷积神经网络
【论文阅读】Multi-View Spectral Clustering with Optimal Neighborhood Laplacian Matrix
强化学习中,Q-Learning与Sarsa的差别有多大?
Pytest常用插件
Copy Siege Lions "sticky" to AI couplets
基于PyTorch的FCN-8s语义分割模型搭建
DeblurGAN-v2: Deblurring (Orders-of-Magnitude) Faster and Better 图像去模糊
The second official example analysis of the MOOSE platform - about creating a Kernel and solving the convection-diffusion equation
Dictionary feature extraction, text feature extraction.
Copy攻城狮5分钟在线体验 MindIR 格式模型生成
随机推荐
Data reading in yolov3 (1)
亚马逊云科技 Build On 2022 - AIot 第二季物联网专场实验心得
AIDL communication between two APPs
TensorFlow2 study notes: 7. Optimizer
Dictionary feature extraction, text feature extraction.
图像合并水平拼接
【论文阅读】Multi-View Spectral Clustering with Optimal Neighborhood Laplacian Matrix
Copy攻城狮5分钟在线体验 MindIR 格式模型生成
【深度学习日记】第一天:Hello world,Hello CNN MNIST
深度学习理论——过拟合、欠拟合、正则化、优化器
光条中心提取方法总结(二)
WARNING: sql version 9.2, server version 11.0. Some psql features might not work.
tensorRT教程——使用tensorRT OP 搭建自己的网络
Golang环境变量设置(二)--GOMODULE&GOPROXY
ValueError: Expected 96 from C header, got 88 from PyObject
强化学习中,Q-Learning与Sarsa的差别有多大?
【CV-Learning】线性分类器(SVM基础)
[Introduction to go language] 12. Pointer
SQL注入详解
Attention Is All You Need(Transformer)