PyTorch applied to MNIST handwritten digit recognition
2022-08-04 02:10:00 【windawdaysss】
Preface
The MNIST handwritten-digit dataset is a common set of images, often used to evaluate and compare the performance of machine learning algorithms. This article uses the PyTorch framework to recognize the digits in this dataset and improves the results step by step.
1. Dataset
The MNIST dataset consists of 28x28 grayscale images with pixel values in the range 0-255 (one is shown below); 60,000 images are used to train the model and 10,000 to test it.
The dataset can be downloaded from the links below:
Training set:
https://pjreddie.com/media/files/mnist_train.csv
Test set:
https://pjreddie.com/media/files/mnist_test.csv
Each row of the dataset contains 785 values: the first is the digit label, and the remaining 784 are the pixel values of the image.
Example code for reading the data:
import pandas
import matplotlib.pyplot as plt

df = pandas.read_csv(r'./data/mnist_train.csv', header=None)
# print(df.head())  # show the first 5 rows
# print(df.info())  # show a summary of the DataFrame

row = 0
data = df.iloc[row]
label = data[0]
img = data[1:].values.reshape(28, 28)
plt.title('label = ' + str(label))
plt.imshow(img, interpolation='none', cmap='Blues')
plt.show()
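As a quick sanity check (a small sketch added here, assuming both CSV files were downloaded to ./data/), the DataFrame shapes should match the 60,000/10,000 split and the 785 values per row described above:

import pandas

train_df = pandas.read_csv(r'./data/mnist_train.csv', header=None)
test_df = pandas.read_csv(r'./data/mnist_test.csv', header=None)
print(train_df.shape)  # expected: (60000, 785) -- 1 label + 784 pixel values
print(test_df.shape)   # expected: (10000, 785)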
2. Building the model
# build the model
import torch
import torch.nn as nn
from torch.utils.data import Dataset
import pandas
import matplotlib.pyplot as plt

class Classifier(nn.Module):

    def __init__(self):
        # initialise the PyTorch parent class
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(784, 200),
            nn.Sigmoid(),
            nn.Linear(200, 10),
            nn.Sigmoid()
        )
        self.loss_function = nn.MSELoss()
        self.optimizer = torch.optim.SGD(self.parameters(), lr=0.01)
        # bookkeeping for visualizing the training process
        self.counter = 0
        self.progress = []

    def forward(self, inputs):
        return self.model(inputs)

    def train_model(self, inputs, targets):
        outputs = self.forward(inputs)
        loss = self.loss_function(outputs, targets)
        self.optimizer.zero_grad()  # zero the gradients, because backpropagated gradients accumulate
        loss.backward()             # backpropagation
        self.optimizer.step()       # update the weights
        # record the loss to visualize the training process
        self.counter += 1
        if self.counter % 10 == 0:
            self.progress.append(loss.item())  # extract the number from the single-value tensor
        if self.counter % 10000 == 0:
            print('counter = ', self.counter)

    def plot_progress(self):
        df = pandas.DataFrame(self.progress, columns=['loss'])
        df.plot(ylim=(0, 1.0), figsize=(16, 8), alpha=0.1, marker='.', grid=True, yticks=(0, 0.25, 0.5))
        plt.show()

class MnistDataset(Dataset):

    def __init__(self, csv_file):
        self.data_df = pandas.read_csv(csv_file, header=None)

    def __len__(self):
        return len(self.data_df)

    def __getitem__(self, index):
        label = self.data_df.iloc[index, 0]
        target = torch.zeros(10)
        target[label] = 1.0
        image_value = torch.FloatTensor(self.data_df.iloc[index, 1:].values) / 255.0
        return label, image_value, target

    def plot_image(self, index):
        arr = self.data_df.iloc[index, 1:].values.reshape(28, 28)
        plt.title('label = ' + str(self.data_df.iloc[index, 0]))
        plt.imshow(arr, interpolation='none', cmap='Blues')
        plt.show()
The code above establishes the model framework, visualizes the training process, and defines a class for reading the data.
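Before training on real data, a quick smoke test can confirm the wiring. This is a minimal sketch (not part of the original article) that feeds one random 784-value tensor through an untrained network and checks that 10 outputs come back:

# smoke test: an untrained Classifier should map 784 inputs to 10 outputs
check_model = Classifier()
dummy_input = torch.rand(784)           # stands in for one flattened 28x28 image
dummy_output = check_model.forward(dummy_input)
print(dummy_output.shape)               # torch.Size([10]) -- one score per digit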
3. Training the classifier
import time

mnist_train_dataset = MnistDataset(r'./data/mnist_train.csv')
# mnist_train_dataset.plot_image(9)

# train the classifier
start_time = time.time()
C = Classifier()
epochs = 3  # train for 3 epochs
for i in range(epochs):
    print('training epoch', i + 1, 'of', epochs)
    for label, image_tensor, target_tensor in mnist_train_dataset:
        C.train_model(image_tensor, target_tensor)
C.plot_progress()
print('run time = ', (time.time() - start_time) / 60)
Training for 3 epochs takes roughly 3 minutes, which is decent efficiency.
4. Testing the model
# test the model
mnist_test_dataset = MnistDataset(r'./data/mnist_test.csv')

record = 19
mnist_test_dataset.plot_image(record)  # show the digit in the image
image_data = mnist_test_dataset[record][1]
output = C.forward(image_data)
pandas.DataFrame(output.detach().numpy()).plot(kind='bar', legend=False, ylim=(0, 1))  # the predicted scores per digit
plt.show()

score = 0
items = 0
for label, img_tensor, label_tensor in mnist_test_dataset:
    ans = C.forward(img_tensor)
    if ans.argmax() == label:
        score += 1
    items += 1
print(score, items, score / items)
The model scores about 87% on the test set; considering this is a very simple network, that is not bad.
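To see where the 87% model goes wrong, a per-digit accuracy breakdown helps. The sketch below is an illustrative addition that reuses the same test loop:

# per-digit accuracy, to spot which digits the network confuses
correct = [0] * 10
totals = [0] * 10
for label, img_tensor, label_tensor in mnist_test_dataset:
    prediction = C.forward(img_tensor).argmax().item()
    totals[label] += 1
    if prediction == label:
        correct[label] += 1
for digit in range(10):
    print(digit, correct[digit] / totals[digit])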
5. Optimizing the model
The model can be optimized in four main ways:
- 1. Loss function
The loss function in the model above is MSELoss; here it is changed to binary cross-entropy loss (BCELoss):
self.loss_function = nn.BCELoss()
After training for 3 epochs, the score improves from 87% to 91%.
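Note that BCELoss expects both outputs and targets to lie in [0, 1], which the final Sigmoid layer guarantees here. A tiny illustration with made-up tensors:

loss_fn = nn.BCELoss()
outputs = torch.tensor([0.9, 0.1, 0.2])  # hypothetical sigmoid outputs
targets = torch.tensor([1.0, 0.0, 0.0])  # one-hot style targets, as used above
print(loss_fn(outputs, targets).item())  # low loss, since outputs are near targets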
- 2. Activation function
One drawback of the Sigmoid activation function is that when the input grows large, the gradient becomes very small or even vanishes. A commonly used improvement is the leaky rectified linear unit, Leaky ReLU.
self.model = nn.Sequential(
    nn.Linear(784, 200),
    # nn.Sigmoid(),
    nn.LeakyReLU(0.02),
    nn.Linear(200, 10),
    # nn.Sigmoid()
    nn.LeakyReLU(0.02)
)
With the loss function kept as the original MSELoss, after 3 epochs the score rises from 87% to 97%, a large improvement.
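The vanishing-gradient claim is easy to verify numerically. A small illustrative sketch (not from the original article) compares the gradients of the two activations at a large input:

# gradient of Sigmoid at x = 10: almost zero
x = torch.tensor(10.0, requires_grad=True)
torch.sigmoid(x).backward()
print(x.grad)   # ~4.5e-05, essentially vanished

# gradient of LeakyReLU at the same input: still 1.0
y = torch.tensor(10.0, requires_grad=True)
nn.LeakyReLU(0.02)(y).backward()
print(y.grad)   # 1.0, a useful learning signal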
- 3. Optimizer
The model above uses plain gradient descent. One drawback of this method is that it can get stuck in local minima of the loss function; another is that it uses a single learning rate for all learnable parameters. A common alternative is Adam, which uses momentum to reduce the chance of getting stuck in a local minimum and maintains a separate learning rate for each learnable parameter, adapting these rates as each parameter changes during training.
self.optimizer = torch.optim.Adam(self.parameters())
Changing only the optimizer has the same effect as changing the activation function: the score improves from 87% to 97%.
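For reference, the Adam call above relies on the library defaults; writing them out explicitly makes later tuning easier (lr=0.001 and betas=(0.9, 0.999) are PyTorch's documented defaults, not values chosen by this article):

self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001, betas=(0.9, 0.999))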
- 4. Normalization
Normalization reduces the range of parameters and signals in the network, shifting the mean to 0. A common practice is to normalize a signal before feeding it into the next layer of the neural network.
self.model = nn.Sequential(
    nn.Linear(784, 200),
    nn.Sigmoid(),
    # nn.LeakyReLU(0.02),
    nn.LayerNorm(200),  # normalization
    nn.Linear(200, 10),
    nn.Sigmoid()
    # nn.LeakyReLU(0.02)
)
Adding normalization to the network improves the score from 87% to 91%.
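The effect of LayerNorm can be observed directly: it rescales each sample's activations to roughly zero mean and unit standard deviation. A small illustration with random data:

layer_norm = nn.LayerNorm(200)
signal = torch.rand(200) * 5.0        # activations with an arbitrary range
normed = layer_norm(signal)
print(normed.mean().item(), normed.std().item())  # ~0.0 and ~1.0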
Finally, combine all of the above methods. Since the binary cross-entropy loss can only handle values in the range 0 to 1, while LeakyReLU may output values outside that range, the activation function of the final layer is kept as the original Sigmoid:
self.model = nn.Sequential(
    nn.Linear(784, 200),
    # nn.Sigmoid(),
    nn.LeakyReLU(0.02),
    nn.LayerNorm(200),
    nn.Linear(200, 10),
    nn.Sigmoid()
    # nn.LeakyReLU(0.02)
)
After 3 epochs of training the model scores 97%; even combined, these optimizations do not push the score above 97%.
END
References
- Tariq Rashid, translated by Han Jianglei. PyTorch生成对抗网络编程. Posts & Telecom Press (人民邮电出版社)