PyTorch RNN in Practice: MNIST Handwritten Digit Recognition
2022-07-06 10:25:00 【How about a song without trace】
# Model 1: PyTorch RNN implementation steps
# Load the dataset
# Make the dataset iterable (read one batch at a time)
# Create the model class
# Instantiate the model class
# Instantiate the loss class
# Train the model
# 1. Import the required libraries
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
# 2. Download the datasets
trainsets = datasets.MNIST(root='./data2', train=True, download=True, transform=transforms.ToTensor())
testsets = datasets.MNIST(root='./data2', train=False, download=True, transform=transforms.ToTensor())
class_names = trainsets.classes  # View the class labels
print(class_names)
# 3. Check the dataset shapes
print(trainsets.data.shape)     # torch.Size([60000, 28, 28])
print(trainsets.targets.shape)  # torch.Size([60000])
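`trainsets.data` holds the raw `uint8` images (values 0 to 255), while `transforms.ToTensor()` converts each sample on access into a `float32` tensor of shape (1, 28, 28) with values in [0, 1]. The sketch below mimics that conversion by hand on a synthetic array; `to_tensor_like` is a hypothetical helper that only approximates what torchvision does for a single-channel `uint8` image.

```python
import numpy as np
import torch

def to_tensor_like(img: np.ndarray) -> torch.Tensor:
    """Hypothetical helper: mimic ToTensor() for a 2-D uint8 grayscale image."""
    t = torch.from_numpy(img).float().div(255.0)  # uint8 0..255 -> float 0..1
    return t.unsqueeze(0)                         # (H, W) -> (1, H, W): add a channel axis

# Synthetic 28x28 "digit": a white vertical stroke on a black background
img = np.zeros((28, 28), dtype=np.uint8)
img[4:24, 13:15] = 255

x = to_tensor_like(img)
print(x.shape, x.dtype, x.min().item(), x.max().item())
```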
# 4. Define the hyperparameters
BATCH_SIZE = 32  # Number of samples read per batch
EPOCHS = 10      # Train for ten epochs
# 5. Create iterable DataLoaders so that the data is read one batch at a time
train_loader = torch.utils.data.DataLoader(dataset=trainsets, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testsets, batch_size=BATCH_SIZE, shuffle=False)  # No need to shuffle evaluation data
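With a batch size of 32, the training loader yields ceil(60000 / 32) = 1875 batches per epoch and the test loader ceil(10000 / 32) = 313, the last test batch holding only 16 images because `drop_last` defaults to `False`. The sketch below checks this batching arithmetic on a small synthetic `TensorDataset` of 100 samples (sizes assumed for illustration) instead of MNIST.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for MNIST: 100 fake 28x28 images with integer labels 0..9
features = torch.randn(100, 1, 28, 28)
targets = torch.randint(0, 10, (100,))
dataset = TensorDataset(features, targets)

loader = DataLoader(dataset, batch_size=32, shuffle=True)  # drop_last=False by default
batch_sizes = [images.size(0) for images, labels in loader]
print(len(loader), batch_sizes)  # 4 batches: three full ones and a final partial one

# Same arithmetic for the real MNIST splits with batch size 32
train_batches = math.ceil(60000 / 32)
test_batches = math.ceil(10000 / 32)
last_test_batch = 10000 - (test_batches - 1) * 32
print(train_batches, test_batches, last_test_batch)
```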
# Look at one batch of data
images, labels = next(iter(test_loader))
print(images.shape)  # torch.Size([32, 1, 28, 28])
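The printed shape is `torch.Size([32, 1, 28, 28])`: batch, channel, height, width. To feed an image to the RNN, the training loop calls `images.view(-1, sequence_dim, input_dim)` with both set to 28, which treats each of the 28 pixel rows as one time step with 28 input features. The sketch below uses a random tensor of the same shape (a stand-in for a real batch) to confirm that time step `t` of the reshaped sequence is exactly row `t` of every image.

```python
import torch

batch = torch.rand(32, 1, 28, 28)  # (batch, channel, height, width), like an MNIST batch
seq = batch.view(-1, 28, 28)       # (batch, seq_len=28 rows, input_dim=28 pixels per row)

print(seq.shape)
# Time step t of every sequence is pixel row t of the corresponding image
rows_match = all(torch.equal(seq[:, t, :], batch[:, 0, t, :]) for t in range(28))
print(rows_match)
```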
# 6. Define a function that displays a batch of images
def imshow(inp, title=None):
    # make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C)
    inp = inp.numpy().transpose((1, 2, 0))
    # The images were only converted with ToTensor() (no Normalize), so they are
    # already in [0, 1] and need no mean/std denormalization
    inp = np.clip(inp, 0, 1)  # Clip values to the [0, 1] range
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
# Display the batch as an image grid
out = torchvision.utils.make_grid(images)
imshow(out)
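`torchvision.utils.make_grid` tiles the batch into one image, by default 8 images per row with 2 pixels of padding, and repeats a single grayscale channel three times so that `imshow` receives an RGB array. For 32 MNIST digits this gives a (3, 122, 242) tensor, since 4 rows × (28 + 2) + 2 = 122 and 8 columns × (28 + 2) + 2 = 242. The NumPy sketch below reproduces only that tiling layout for grayscale images; `tile_images` is a hypothetical, simplified helper, not torchvision's implementation.

```python
import math
import numpy as np

def tile_images(images: np.ndarray, nrow: int = 8, padding: int = 2) -> np.ndarray:
    """Hypothetical helper: tile (N, H, W) grayscale images into one padded 2-D grid."""
    n, h, w = images.shape
    ncols = min(nrow, n)
    nrows = math.ceil(n / ncols)
    cell_h, cell_w = h + padding, w + padding
    # Background (padding) is 0; the grid also gets padding along the bottom/right edge
    grid = np.zeros((nrows * cell_h + padding, ncols * cell_w + padding), dtype=images.dtype)
    for k in range(n):
        row, col = divmod(k, ncols)
        top, left = row * cell_h + padding, col * cell_w + padding
        grid[top:top + h, left:left + w] = images[k]
    return grid

batch = np.random.rand(32, 28, 28).astype(np.float32)  # stand-in for 32 MNIST digits
grid = tile_images(batch)
print(grid.shape)
```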
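# 7. Define the RNN model
class RNN_Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(RNN_Model, self).__init__()
        self.hidden_dim = hidden_dim  # Number of features in the hidden state
        self.layer_dim = layer_dim    # Number of stacked RNN layers
        # batch_first=True: the input has shape (batch, seq_len, input_dim)
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
        # Fully connected layer: maps the last hidden output to the class scores
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initial hidden state (layer_dim, batch, hidden_dim), created on the same device as x
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        out, hn = self.rnn(x, h0)
        # Classify using only the output of the last time step
        out = self.fc(out[:, -1, :])
        return out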
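With `nonlinearity='relu'`, each `nn.RNN` layer computes h_t = ReLU(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh), and the second layer of a two-layer RNN takes the first layer's outputs as its inputs. `forward` then keeps only `out[:, -1, :]`, the top layer's output at the last time step (the last image row). The sketch below checks both points on a single-layer `nn.RNN` with small, assumed sizes: a hand-written loop over time steps matches PyTorch's output, and the last time step of `out` equals the final hidden state `hn[-1]`. Note that `h0` and `hn` keep the layout (num_layers, batch, hidden) even when `batch_first=True`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch, seq_len, input_dim, hidden_dim = 4, 5, 3, 6  # small sizes chosen for the check

rnn = nn.RNN(input_dim, hidden_dim, num_layers=1, batch_first=True, nonlinearity='relu')
x = torch.randn(batch, seq_len, input_dim)
h0 = torch.zeros(1, batch, hidden_dim)  # (num_layers, batch, hidden) even with batch_first=True

out, hn = rnn(x, h0)  # out: (batch, seq_len, hidden), hn: (num_layers, batch, hidden)

# Manual recurrence using the layer's own weights
h = h0[0]
manual_steps = []
with torch.no_grad():
    for t in range(seq_len):
        h = torch.relu(x[:, t, :] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                       + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
        manual_steps.append(h)
manual_out = torch.stack(manual_steps, dim=1)  # (batch, seq_len, hidden)

print(out.shape, hn.shape)
print(torch.allclose(out, manual_out, atol=1e-5))
# For a unidirectional RNN, the last time step of `out` is the top layer's final hidden state
print(torch.allclose(out[:, -1, :], hn[-1]))
```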
# 8. Initialize the model
input_dim = 28    # Input dimension (pixels per image row)
hidden_dim = 100  # Hidden dimension
layer_dim = 2     # 2-layer RNN
output_dim = 10   # Output dimension (10 digit classes)
# Check whether a GPU is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Instantiate the model, pass in the parameters, and move it to the device
model = RNN_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)
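Parameters and inputs must live on the same device: calling `.to(device)` on a model moves its weights, and every batch must be moved with `.to(device)` as well. The sketch below shows this device-agnostic pattern with a small stand-in model (an `nn.Linear`, assumed here for brevity); it falls back to the CPU when no GPU is present.

```python
import torch
import torch.nn as nn

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

model = nn.Linear(28, 10).to(device)    # Moves the weight and bias to `device`
batch = torch.randn(32, 28).to(device)  # Every input batch must be moved as well

logits = model(batch)
param_device = next(model.parameters()).device
print(param_device, logits.device)
```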
# 9. Define the loss function
criterion = nn.CrossEntropyLoss()
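`nn.CrossEntropyLoss` expects raw, unnormalized scores (logits) of shape (batch, classes) together with integer class labels; internally it applies `log_softmax` followed by the negative log-likelihood, averaged over the batch. That is why the model's final `nn.Linear` layer has no softmax. The sketch below checks that equivalence on random logits (assumed data).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 10)          # 8 samples, 10 digit classes (raw scores, no softmax)
labels = torch.randint(0, 10, (8,))  # Integer class indices

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# Same value computed by hand: mean of -log(softmax probability of the true class)
log_probs = F.log_softmax(logits, dim=1)
manual = -log_probs[torch.arange(8), labels].mean()
print(loss.item(), manual.item())
```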
# 10. Define the optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
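Plain `torch.optim.SGD` (no momentum or weight decay, as configured here) updates every parameter as p ← p − lr · p.grad. Gradients from successive `backward()` calls are added into `.grad`, which is why the training loop calls `optimizer.zero_grad()` before each batch. The sketch below demonstrates both on a single scalar parameter with an assumed toy loss (p − 3)².

```python
import torch

p = torch.tensor(1.0, requires_grad=True)
optimizer = torch.optim.SGD([p], lr=0.01)

def toy_loss(param):
    # Toy loss (p - 3)^2, whose gradient is 2 * (p - 3)
    return (param - 3.0) ** 2

# Two backward passes without zero_grad: the gradients accumulate
toy_loss(p).backward()
toy_loss(p).backward()
accumulated_grad = p.grad.item()  # 2 * (2 * (1 - 3)) = -8.0

# Clear the gradient, recompute it once, then take one SGD step
optimizer.zero_grad()
toy_loss(p).backward()
single_grad = p.grad.item()       # 2 * (1 - 3) = -4.0
optimizer.step()                  # p <- p - lr * grad = 1.0 - 0.01 * (-4.0) = 1.04
print(accumulated_grad, single_grad, p.item())
```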
# 11. Output the model parameters
# 12. Print the size of each parameter tensor in turn
for i, param in enumerate(model.parameters()):
    print('Parameter %d:' % (i + 1))
    print(param.size())
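For `input_dim=28`, `hidden_dim=100`, `layer_dim=2` and `output_dim=10`, the loop prints 10 parameter tensors: per RNN layer a `weight_ih` (100×28 for layer 0, 100×100 for layer 1), a `weight_hh` (100×100) and two biases of size 100, followed by the fully connected layer's 10×100 weight and size-10 bias. That is 13,000 + 20,200 + 1,010 = 34,210 trainable values. The sketch below recomputes this with the same layer configuration; it builds the two layers directly rather than reusing `RNN_Model`, so it runs on its own.

```python
import torch.nn as nn

rnn = nn.RNN(28, 100, 2, batch_first=True, nonlinearity='relu')
fc = nn.Linear(100, 10)

named = list(rnn.named_parameters()) + list(fc.named_parameters())
for name, param in named:
    print(name, tuple(param.shape))

total = sum(param.numel() for _, param in named)
print('Total parameters:', total)
```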
# 13. Train the model
sequence_dim = 28    # Sequence length (28 rows per image)
loss_list = []       # Saved loss values
accuracy_list = []   # Saved accuracy values
iteration_list = []  # Saved iteration counts
iteration = 0        # Named "iteration" to avoid shadowing the built-in iter()
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        model.train()  # Switch to training mode
        # Reshape a batch (B, 1, 28, 28) into the RNN input shape (B, 28, 28)
        images = images.view(-1, sequence_dim, input_dim).to(device)
        labels = labels.to(device)
        # Clear the gradients (otherwise they keep accumulating)
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Increment the iteration counter
        iteration += 1
        # Evaluate the model every 500 iterations
        if iteration % 500 == 0:
            model.eval()  # Switch to evaluation mode
            correct = 0
            total = 0
            # Iterate over the test set and predict, without tracking gradients
            with torch.no_grad():
                for test_images, test_labels in test_loader:
                    test_images = test_images.view(-1, sequence_dim, input_dim).to(device)
                    test_labels = test_labels.to(device)
                    # Model prediction
                    outputs = model(test_images)
                    # Index of the highest score = predicted class
                    predict = torch.max(outputs, 1)[1]
                    # Count the test samples
                    total += test_labels.size(0)
                    # Count the correct predictions
                    correct += (predict == test_labels).sum().item()
            # Accuracy as a percentage
            accuracy = 100 * correct / total
            # Save the loss, accuracy and iteration count for plotting
            loss_list.append(loss.item())
            accuracy_list.append(accuracy)
            iteration_list.append(iteration)
            # Print progress
            print("epoch: {}, iteration: {}, Loss: {:.4f}, Accuracy: {:.2f}%".format(
                epoch + 1, iteration, loss.item(), accuracy))
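During evaluation, `torch.max(outputs, 1)[1]` returns, for each sample, the index of the highest score, i.e. the predicted digit (equivalent to `outputs.argmax(dim=1)`), and accuracy as a percentage is 100 × correct / total. The sketch below uses a hand-made batch of four logit rows (assumed values) so the result can be checked by eye: three of the four predictions match the labels.

```python
import torch

outputs = torch.tensor([[2.0, 0.1, 0.3],   # Predicted class 0
                        [0.2, 1.5, 0.1],   # Predicted class 1
                        [0.3, 0.2, 0.9],   # Predicted class 2
                        [1.2, 0.4, 0.8]])  # Predicted class 0
labels = torch.tensor([0, 1, 2, 2])

with torch.no_grad():  # No gradients are needed for evaluation
    predict = torch.max(outputs, 1)[1]
    correct = (predict == labels).sum().item()
total = labels.size(0)
accuracy = 100 * correct / total
print(predict.tolist(), correct, accuracy)
```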
# Visualize the loss
plt.plot(iteration_list, loss_list)
plt.xlabel('Number of Iteration')
plt.ylabel('Loss')
plt.title('RNN')
plt.show()
# Visualize the accuracy
plt.plot(iteration_list, accuracy_list, color='r')
plt.xlabel('Number of Iteration')
plt.ylabel('Accuracy')
plt.title('RNN')
plt.savefig('RNN_mnist.png')
plt.show()