PyTorch RNN Hands-On Example: MNIST Handwritten Digit Recognition
2022-07-06 10:25:00 【How about a song without trace】
# Model 1: PyTorch RNN implementation process
# Load the dataset
# Make the dataset iterable (read one batch at a time)
# Create the model class
# Initialize the model class
# Initialize the loss class
# Train the model
# 1. Load the dataset
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 2. Download the dataset
trainsets = datasets.MNIST(root='./data2', train=True, download=True, transform=transforms.ToTensor())
testsets = datasets.MNIST(root='./data2', train=False, transform=transforms.ToTensor())
class_names = trainsets.classes  # view the class labels
print(class_names)
# 3. Check the shape of the dataset
print(trainsets.data.shape)     # torch.Size([60000, 28, 28])
print(trainsets.targets.shape)  # torch.Size([60000])
# 4. Define hyperparameters
BATCH_SIZE = 32  # number of samples read in each batch
EPOCHS = 10      # train for ten epochs
# Create iterable DataLoaders over the datasets, i.e. read the data one batch at a time
train_loader = torch.utils.data.DataLoader(dataset=trainsets, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testsets, batch_size=BATCH_SIZE, shuffle=False)
# View one batch of data
images, labels = next(iter(test_loader))
print(images.shape)  # torch.Size([32, 1, 28, 28])
# 6. Define a helper function to display one batch of images
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib
    mean = np.array([0.485, 0.456, 0.406])  # per-channel mean
    std = np.array([0.229, 0.224, 0.225])   # per-channel standard deviation
    inp = std * inp + mean                  # un-normalize for display
    inp = np.clip(inp, 0, 1)                # clip values to [0, 1]
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

# Show the batch as a grid
out = torchvision.utils.make_grid(images)
imshow(out)
# 7. Define the RNN model
class RNN_Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(RNN_Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True -> input shape is (batch, seq_len, input_dim)
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
        # Fully connected layer mapping the last hidden state to the class scores
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initial hidden state: (layer_dim, batch_size, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).to(x.device)
        out, hn = self.rnn(x, h0)
        # Classify based on the output of the last time step
        out = self.fc(out[:, -1, :])
        return out
# 8. Initialize the model
input_dim = 28    # input dimension (28 pixels per image row, one row per time step)
hidden_dim = 100  # hidden dimension
layer_dim = 2     # 2 stacked RNN layers
output_dim = 10   # output dimension (10 digit classes)
# Check whether a GPU is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate the model, pass in the parameters, and move it to the device
model = RNN_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)
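# Optional sanity check (a minimal sketch with made-up dummy data, not in the
# original post): each 28x28 image is treated as a sequence of 28 rows, i.e.
# 28 time steps with 28 features each, so the model maps (batch, 28, 28) -> (batch, 10).
dummy = torch.randn(4, 28, 28).to(device)  # fake batch of 4 "images"
with torch.no_grad():
    print(model(dummy).shape)  # expected: torch.Size([4, 10])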
# 9. Define the loss function
criterion = nn.CrossEntropyLoss()
# 10. Define the optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
# 11. Count the model parameter tensors
length = len(list(model.parameters()))
# 12. Loop over the parameters and print their shapes
for i in range(length):
    print('Parameter: %d' % (i + 1))
    print(list(model.parameters())[i].size())
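# A minimal sketch (not part of the original post): the total number of
# trainable parameters can be summed as a quick cross-check.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total trainable parameters:', total_params)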
# 13. Train the model
sequence_dim = 28    # sequence length (28 rows per image)
loss_list = []       # record loss
accuracy_list = []   # record accuracy
iteration_list = []  # record iteration count
iter = 0
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        model.train()  # switch to training mode
        # Reshape one batch to the RNN input shape (batch, seq_len, input_dim)
        images = images.view(-1, sequence_dim, input_dim).requires_grad_().to(device)
        labels = labels.to(device)
        # Clear gradients (otherwise they keep accumulating)
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update parameters
        optimizer.step()
        # Increment the iteration counter
        iter += 1
        # Validate every 500 iterations
        if iter % 500 == 0:
            model.eval()  # switch to evaluation mode
            # Compute accuracy on the test set
            correct = 0.0
            total = 0.0
            # Iterate over the test set, get data, predict
            with torch.no_grad():
                for images, labels in test_loader:
                    images = images.view(-1, sequence_dim, input_dim).to(device)
                    labels = labels.to(device)
                    # Model prediction
                    outputs = model(images)
                    # Index of the maximum predicted score
                    predict = torch.max(outputs.data, 1)[1]
                    # Count the test set size
                    total += labels.size(0)
                    # Count the correct predictions
                    correct += (predict == labels).sum().item()
            # Accuracy in percent
            accuracy = correct / total * 100
            # Save accuracy, loss and iteration count
            loss_list.append(loss.item())
            accuracy_list.append(accuracy)
            iteration_list.append(iter)
            # Print progress
            print("iteration : {}, Loss : {}, Accuracy : {}".format(iter, loss.item(), accuracy))
# Visualize the loss
plt.plot(iteration_list, loss_list)
plt.xlabel('Number of Iterations')
plt.ylabel('Loss')
plt.title('RNN')
plt.show()
# Visualize the accuracy
plt.plot(iteration_list, accuracy_list, color='r')
plt.xlabel('Number of Iterations')
plt.ylabel('Accuracy')
plt.title('RNN')
plt.savefig('RNN_mnist.png')
plt.show()
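# A minimal sketch (not in the original post) of using the trained model to
# classify a single test image; the variable names below are just illustrative.
model.eval()
sample_img, sample_label = testsets[0]  # one test image of shape (1, 28, 28), and its label
with torch.no_grad():
    logits = model(sample_img.view(-1, sequence_dim, input_dim).to(device))
    pred = torch.max(logits, 1)[1].item()
print('predicted: {}, actual: {}'.format(pred, sample_label))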