当前位置:网站首页>pytorch应用于MNIST手写字体识别
pytorch应用于MNIST手写字体识别
2022-08-04 01:59:00 【windawdaysss】
前言
手写字体MNIST数据集是一组常见的图像,其常用于测评和比较机器学习算法的性能,本文使用pytorch框架来实现对该数据集的识别,并对结果进行逐步的优化。
一、数据集
MNIST数据集是由28x28大小的0-255像素值范围的灰度图像(如下图所示),其中6万张用于训练模型,1万张用于测试模型。
该数据集可从以下链接获取:
训练数据集:
https://pjreddie.com/media/files/mnist_train.csv
测试数据集:
https://pjreddie.com/media/files/mnist_test.csv
数据集一行有785个值,第一个值为图像中的数字标签,其余784个值为图像的像素值。
读取数据实例代码如下:
import pandas
import matplotlib.pyplot as plt
df = pandas.read_csv(r'./data/mnist_train.csv', header=None)
# print(df.head()) # 显示前5行
# print(df.info()) # 显示DataFrame概况
row = 0
data = df.iloc[row]
label = data[0],
img = data[1:].values.reshape(28, 28)
plt.title('label = ' + str(label))
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()

二、建立模型
# 构建模型
import torch
import torch.nn as nn
from torch.utils.data import Dataset
class Classifier(nn.Module):
def __init__(self):
# 初始化pytorch父类
super().__init__()
self.model = nn.Sequential(
nn.Linear(784, 200),
nn.Sigmoid(),
nn.Linear(200, 10),
nn.Sigmoid()
)
self.loss_function = nn.MSELoss()
self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
self.counter = 0
self.progress = []
def forward(self, inputs):
return self.model(inputs)
def train_model(self, inputs, targets):
outputs = self.forward(inputs)
loss = self.loss_function(outputs, targets)
self.optimizer.zero_grad() # 梯度归零 ,因为反向传播计算的梯度会累计
loss.backward() # 反向传播
self.optimizer.step() # 更新权重
# 可视化训练过程
self.counter += 1
if self.counter % 10 == 0:
self.progress.append(loss.item()) # 获取单张张量里的数字
pass
if self.counter % 10000 == 0:
print('counter = ', self.counter)
pass
def plot_progress(self):
df = pandas.DataFrame(self.progress, columns=['loss'])
df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
plt.show()
pass
class MnistDataset(Dataset):
def __init__(self, csv_file):
self.data_df = pandas.read_csv(csv_file, header=None)
pass
def __len__(self):
return len(self.data_df)
def __getitem__(self, index):
label = self.data_df.iloc[index, 0]
target = torch.zeros((10))
target[label] = 1
image_value = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0
return label, image_value, target
def plot_image(self, index):
arr = self.data_df.iloc[index, 1:].values.reshape(28, 28)
plt.title('label = ' + str(self.data_df.iloc[index, 0]))
plt.imshow(arr, interpolation='none', cmap='Blues')
plt.show()
pass
pass
以上建立模型框架,并对训练过程进行可视化,建立读取数据类。
三、训练分类模型
mnist_train_dataset = MnistDataset(r'./data/mnist_train.csv')
# mnist_dataset.plot_image(9)
# 训练分类模型
start_time = time.time()
C = Classifier()
epochs = 3 # 训练3轮
for i in range(epochs):
print('training epoch ', i+1, 'of', epochs)
for lable, image_tensor, target_tensor in mnist_train_dataset:
C.train_model(image_tensor, target_tensor)
pass
pass
C.plot_process()
print('run time = ', (time.time()-start_time) / 60)
训练3轮所花费的时间大约不到3min,效率还不错
四、测试模型
# 测试模型
mnist_test_dataset = MnistDataset(r'./data/mnist_test.csv')
record = 19
mnist_test_dataset.plot_image(record) # 图像里的数字
image_data = mnist_test_dataset[record][1]
output = C.forward(image_data)
pandas.DataFrame(output.detach().numpy()).plot(kind='bar', legend=False, ylim=(0, 1)) # 预测的数字
plt.show()
score = 0
items = 0
for label, img_tensor, label_tensor in mnist_test_dataset:
ans = C.forward(img_tensor)
if ans.argmax() == label:
score += 1
pass
items += 1
pass
print(score, items, score / items)
模型的测试分数是87%,考虑到这是一个简单的网络,这个分数不算太差。
五、模型优化
模型的优化主要从四个方面着手:
- 1、损失函数
在上面的模型中设计损失函数为MSEloss,这里将其更改为二元交叉熵损失((binary cross entropy loss)
self.loss_function = nn.BCELoss()
训练3轮,发现分数由87%提升到91%了
- 2、激活函数
Sigmoid激活函数的一个缺点是,当输入值变大时,梯度会变得非常小甚至消失。现在常用的是改进过的线性整流函数Leaky ReLU,也叫带泄露线性整流函数。
self.model = nn.Sequential(
nn.Linear(784, 200),
# nn.Sigmoid(),
nn.LeakyReLU(0.02),
nn.Linear(200, 10),
# nn.Sigmoid()
nn.LeakyReLU(0.02)
)
损失函数为原来的MSEloss,训练3轮,分数由87%上升到97%,这是一个很大的提升。
- 3 、优化器
上面模型所使用的是梯度下降法,该方法的一个缺点是会陷入损失函数的局部最小值,另一个缺点是对所有可学习参数都使用同一学习率。常见的替代方案是Adam,它利用动量减少陷入局部最小的可能,另外它对每个可学习参数使用单独的学习率,这些学习率随着每个参数在训练期间的变化而变化。
self.optimizer = torch.optim.Adam(self.parameters())
仅改变优化器发现模型达到和修改激活函数一样的效果,分数由87%提升到97%。
- 4、标准化
标准化是指减少网络中的参数和信号的取值范围,将均值转换为0,常见做法是在信号输入到神经网络前将其进行标准化。
self.model = nn.Sequential(
nn.Linear(784, 200),
nn.Sigmoid(),
# nn.LeakyReLU(0.02),
nn.LayerNorm(200), # 标准化
nn.Linear(200, 10),
nn.Sigmoid()
# nn.LeakyReLU(0.02)
)
向网络中添加标准化,模型的分数87%提升到91%
将以上所有方法进行整合,由于二元交叉熵函数只能处理0~1的值,而LeakyReLU可能会输出范围外的值,将后一层激活函数保留为原来的Sigmoid函数:
self.model = nn.Sequential(
nn.Linear(784, 200),
# nn.Sigmoid(),
nn.LeakyReLU(0.02),
nn.LayerNorm(200),
nn.Linear(200, 10),
nn.Sigmoid()
# nn.LeakyReLU(0.02)
)
3周期训练完后,模型的分数为97%,整合的优化方案无法使模型分数大于97%。
END
参考资料
-[英]塔里克•拉希德(Tariq Rashid)著,韩江雷译. PyTorch生成对抗网络编程. 人民邮电出版社
边栏推荐
- 香港服务器有哪些常用的型号
- 简单的线性表的顺序表示实现,以及线性表的链式表示和实现、带头节点的单向链表,C语言简单实现一些基本功能
- 什么是SVN(Subversion)?
- html select tag assignment database query result
- nodejs installation and environment configuration
- MallBook 助力SKT思珂特教育集团,立足变化,拥抱敏捷交易
- Flask Framework Beginner-06-Add, Delete, Modify and Check the Database
- v-model
- Example 040: Reverse List
- 浏览器存储
猜你喜欢

螺旋矩阵_数组 | leecode刷题笔记

Security First: Tools You Need to Know to Implement DevSecOps Best Practices

Flask Framework Beginner-06-Add, Delete, Modify and Check the Database

Web APIs BOM- 操作浏览器:swiper 插件

织梦响应式酒店民宿住宿类网站织梦模板(自适应手机端)

Web APIs BOM - operating browser: swiper plug-in

【store商城项目01】环境准备以及测试

实例038:矩阵对角线之和

关联接口测试

ssh服务详解
随机推荐
【QT小记】QT中信号和槽的基本使用
大佬们,读取mysql300万单表要很长时间,有什么参数可以优惠,或者有什么办法可以快点
持续投入商品研发,叮咚买菜赢在了供应链投入上
C program compilation and predefined detailed explanation
计算首屏时间
什么是SVN(Subversion)?
SAP SD模块前台操作
实例038:矩阵对角线之和
Priority_queue element as a pointer, the overloaded operators
小甲鱼汇编笔记
企业虚拟偶像产生了实质性的价值效益
持续投入商品研发,叮咚买菜赢在了供应链投入上
一个项目的整体测试流程有哪几个阶段?测试方法有哪些?
参加Oracle OCP和MySQL OCP考试的学员怎样在VUE预约考试
实例041:类的方法与变量
Android interview questions and answer analysis of major factories in the first half of 2022 (continuously updated...)
Day13 Postman的使用
Continuing to pour money into commodities research and development, the ding-dong buy vegetables in win into the supply chain
Download install and create/run project for HBuilderX
实例040:逆序列表