当前位置:网站首页>MNIST手写数字识别 —— ResNet-经典卷积神经网络
MNIST手写数字识别 —— ResNet-经典卷积神经网络
2022-08-04 05:30:00 【学习历险记】
了解ResNet18的网络结构;掌握模型的保存和加载方法;掌握批量测试图片的方法。
结合图像分类任务,使用典型的图像分类网络ResNet18,实现手写数字识别。
ResNet作为经典的图像分类网络有其明显的优点:
首先,它足够深,常见的有34层,50层,101层。通常层次越深,表征能力越强,分类准确率越高。
其次,可学习,采用了残差结构,通过shortcut连接把低层直接跟高层相连,解决了反向传播过程中因为网络太深造成的梯度消失问题。
此外,ResNet网络的性能很好,既表现为识别的准确率,也包括它本身模型的大小和参数量。
1. 加载并处理数据集
import os
import sys
import moxing as mox
datasets_dir = '../datasets'
if not os.path.exists(datasets_dir):
os.makedirs(datasets_dir)
if not os.path.exists(os.path.join(datasets_dir, 'MNIST_Data.zip')):
mox.file.copy('obs://modelarts-labs-bj4-v2/course/hwc_edu/python_module_framework/datasets/mindspore_data/MNIST_Data.zip',
os.path.join(datasets_dir, 'MNIST_Data.zip'))
os.system('cd %s; unzip MNIST_Data.zip' % (datasets_dir))
sys.path.insert(0, os.path.join(os.getcwd(), '../datasets/MNIST_Data'))
from load_data_all import load_data_all
from process_dataset import process_dataset
mnist_ds_train, mnist_ds_test, train_len, test_len = load_data_all(datasets_dir) # 加载数据集
mnist_ds_train = process_dataset(mnist_ds_train, batch_size= 64, resize= 28) # 处理训练集,分批加载
mnist_ds_test = process_dataset(mnist_ds_test, batch_size= 32, resize= 28) # 处理测试集, 分批加载
训练集规模:60000,测试集规模:10000
2. 下载构建好的resnet18网络源码文件
为了让开发者更好地体验MindSpore框架优势,MindSpore Model Zoo已经添加了更多典型网络和相关预训练模型,涉及到计算机视觉、自然语言处理、推荐系统、图神经网络等领域。其中的ResNet系列网络模型也已经使用MindSpore实现。
2.1. 下载网络源码文件
采用ResNet-18 实现手写数字识别任务,需要将resnet.py
下载下来,才可以使用MindSpore定义好的网络结构。
# 下载构建好的网络源文件,只需执行一次即可。
!wget -N https://modelarts-labs-bj4-v2.obs.cn-north-4.myhuaweicloud.com/course/mindspore/mnist_recognition/src/resnet.py --no-check-certificate
2.2.修改网络文件中通道数
由于使用的是单通道的灰度图像数据,这里需要将原网络结果中的第一层卷积层的输入通道 3改为1,即将resnet.py
中的第387行的 3 改为 1 即可,这里使用linux中的sed命令来编辑文本。
# 此命令执行后将直接在原文件中修改,不会有任何输出
!sed -i '387s/3/1/g' ./resnet.py
3. 载入resnet18网络
from resnet import resnet18
network = resnet18(class_num=10)
4. 定义损失函数和优化器
import mindspore
import mindspore.nn as nn
lr = 0.01
momentum = 0.9
net_loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') # 损失函数
net_opt = nn.Momentum(network.trainable_params(), lr, momentum) # 优化器
5. 配置运行信息
from mindspore import context
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
6. 调用Model高阶API进行训练和保存模型文件
模型训练包含两层迭代,数据集的多轮迭代(epoch)和一轮数据集内按分组(batch)大小进行的单步迭代。
为了简化训练过程,MindSpore封装了Model高阶接口:
用户输入网络、损失函数和优化器完成Model的初始化;
调用train接口进行训练,train接口参数包括迭代次数(epoch)和数据集(dataset);
调用Model的eval接口预测新图像类别;
模型保存是对训练参数进行持久化的过程。Model类中通过回调函数(callback)的方式进行模型保存。
这里对数据进行一轮迭代,训练耗时约30秒
import os,time
from mindspore import Model
from mindspore import load_checkpoint, load_param_into_net
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
model = Model(network, loss_fn=net_loss, optimizer=net_opt, metrics={
'acc'}) # 完成Model初始化
# 训练参数
batch_num = mnist_ds_train.get_dataset_size()
max_epochs = 1
model_path = "./models/ckpt/"
os.system('rm -f {0}*.ckpt {0}*.meta {0}*.pb'.format(model_path))
# 定义回调函数
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num, keep_checkpoint_max=35)
ckpoint_cb = ModelCheckpoint(prefix="train_resnet_mnist", directory=model_path, config=config_ck)
loss_cb = LossMonitor(batch_num) # 用于输出损失
start_time = time.time()
model.train(max_epochs, mnist_ds_train, callbacks=[ckpoint_cb, loss_cb]) # 训练
res = model.eval(mnist_ds_test) # 验证测试集
print("result: ", res)
cost_time = time.time() - start_time
print("训练总耗时: %.1f s" % cost_time)
epoch: 1 step: 937, loss is 0.039122876 result: {'acc': 0.9818709935897436} 训练总耗时: 16.7 s
从上面的输出结果,可以看到ResNet18模型仅训练一个epoch,耗时仅17秒左右,就在手写数字识别任务的测试集上达到了0.98以上的准确率。该准确率可以达到应用水平,下面将保存模型、并加载模型进行批量图片预测,看看真实的预测效果如何。
查询训练过程中,保存好的模型
!tree ./models/ckpt/
./models/ckpt/ ├── train_resnet_mnist-1_937.ckpt └── train_resnet_mnist-graph.meta 0 directories, 2 files
每937个step保存一次模型权重参数.ckpt文件,一共保存了1个,另外.meta文件保存模型的计算图信息。
7. 批量图片的预测显示
import numpy as np
from PIL import Image
import mindspore
import mindspore.ops as ops
from mindspore import Tensor
dic_ds_test = mnist_ds_test.create_dict_iterator(output_numpy=True)
ds_test = next(dic_ds_test)
images_test = ds_test["image"]
labels_test = ds_test["label"]
output = model.predict(Tensor(images_test))
pred_labels = ops.Argmax(output_type=mindspore.int32)(output)
print("预测值 -- > ", pred_labels) # 打印预测值
print("真实值 -- > ", labels_test) # 打印真实值
batch_img = np.squeeze(images_test[0])
for i in range(1, len(labels_test)):
batch_img = np.hstack((batch_img, np.squeeze(images_test[i]))) # 将一批图片水平拼接起来,方便下一步进行显示
Image.fromarray((batch_img*255).astype('uint8'), mode= "L") # 显示真实值
预测值 -- > [0 3 8 5 2 0 1 2 8 5 0 3 4 5 4 5 6 2 9 4 8 0 1 1 7 5 6 7 8 5 9 4] 真实值 -- > [0 3 1 5 2 0 1 2 8 5 0 3 4 5 4 5 6 2 4 4 8 0 1 1 7 5 6 7 8 5 9 4]
从预测结果可以看出,32张图片只有一张错误预测:1预测成了8,说明模型的预测效果还是不错的。
边栏推荐
猜你喜欢
fuser 使用—— YOLOV5内存溢出——kill nvidai-smi 无pid 的 GPU 进程
安卓连接mysql数据库,使用okhttp
PP-LiteSeg
fill_between in Matplotlib; np.argsort() function
【深度学习21天学习挑战赛】3、使用自制数据集——卷积神经网络(CNN)天气识别
【CV-Learning】Convolutional Neural Network
【CV-Learning】卷积神经网络预备知识
【CV-Learning】卷积神经网络
安装dlib踩坑记录,报错:WARNING: pip is configured with locations that require TLS/SSL
[CV-Learning] Semantic Segmentation
随机推荐
PCL窗口操作
ValueError: Expected 96 from C header, got 88 from PyObject
PostgreSQL schema (Schema)
sklearn中的pipeline机制
[Deep Learning 21-Day Learning Challenge] 3. Use a self-made dataset - Convolutional Neural Network (CNN) Weather Recognition
Qt日常学习
thymeleaf中 th:href使用笔记
PCL1.12 解决memory.h中EIGEN处中断问题
强化学习中,Q-Learning与Sarsa的差别有多大?
【go语言入门笔记】12、指针
Endnote编辑参考文献
逻辑回归---简介、API简介、案例:癌症分类预测、分类评估法以及ROC曲线和AUC指标
【代码学习】
图像合并水平拼接
【深度学习21天学习挑战赛】0、搭建学习环境
度量学习(Metric learning)—— 基于分类损失函数(softmax、交叉熵、cosface、arcface)
多项式回归(PolynomialFeatures)
TensorFlow2 study notes: 7. Optimizer
Polynomial Regression (PolynomialFeatures)
简单说Q-Q图;stats.probplot(QQ图)