PyTorch RNN Hands-On Example: MNIST Handwritten Digit Recognition
2022-07-06 10:25:00 【How about a song without trace】
# Model 1: PyTorch RNN implementation workflow
# Load the dataset
# Make the dataset iterable (read one batch at a time)
# Create the model class
# Initialize the model
# Initialize the loss function
# Train the model
# 1. Import the required libraries
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# 2. Download the training and test datasets
trainsets = datasets.MNIST(root = './data2',train = True,download = True,transform = transforms.ToTensor())
testsets = datasets.MNIST(root = './data2',train = False,transform=transforms.ToTensor())
class_names = trainsets.classes # View category labels
print(class_names)
# 3. Check the dataset shapes
print(trainsets.data.shape)
print(trainsets.targets.shape)
# 4. Define the hyperparameters
BATCH_SIZE = 32   # number of samples read per batch
EPOCHS = 10       # number of training epochs
# 5. Wrap the datasets in DataLoaders so data is read one batch at a time
train_loader = torch.utils.data.DataLoader(dataset = trainsets, batch_size = BATCH_SIZE, shuffle = True)
test_loader = torch.utils.data.DataLoader(dataset = testsets, batch_size = BATCH_SIZE, shuffle = True)
# Inspect one batch of data
images, labels = next(iter(test_loader))
print(images.shape)
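# Optional shape check, not part of the original post: the DataLoader yields
# tensors of shape (batch, 1, 28, 28), and the RNN defined below treats each
# 28x28 image as a sequence of 28 rows with 28 pixel features each.
print(images.view(-1, 28, 28).shape)   # expected: torch.Size([32, 28, 28])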
# 6. Define a helper function to display a batch of images
def imshow(inp, title=None):
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])  # mean (standard ImageNet values, kept from the original post)
    std = np.array([0.229, 0.224, 0.225])   # standard deviation (standard ImageNet values)
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)  # clip the values to the range [0, 1]
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
# Display the batch as a grid
out = torchvision.utils.make_grid(images)
imshow(out)
# 7. Define the RNN model
class RNN_Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(RNN_Model, self).__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # layer_dim stacked RNN layers, ReLU non-linearity, batch dimension first
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first = True, nonlinearity='relu')
        # Fully connected layer mapping the last hidden state to the class scores
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initial hidden state of shape (layer_dim, batch_size, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim).requires_grad_().to(device)
        out, hn = self.rnn(x, h0.detach())
        # Use only the output of the last time step for classification
        out = self.fc(out[:, -1, :])
        return out
# 8. Initialize the model
input_dim = 28    # input dimension (28 pixels per image row)
hidden_dim = 100  # hidden state dimension
layer_dim = 2     # number of stacked RNN layers
output_dim = 10   # output dimension (10 digit classes)
# Use the GPU if one is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Instantiate the model with the parameters above and move it to the device
model = RNN_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)
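# A minimal sanity check, not part of the original post: push a random batch of
# four 28-step sequences through the untrained model and confirm it returns one
# score per digit class for each sample.
dummy = torch.randn(4, 28, 28).to(device)
print(model(dummy).shape)   # expected: torch.Size([4, 10])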
# 9. Define the loss function (cross-entropy for multi-class classification)
criterion = nn.CrossEntropyLoss()
# 10. Define the optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)
# 11. Count the model's parameter tensors
length = len(list(model.parameters()))
# 12. Print the shape of each parameter tensor
for i in range(length):
    print('Parameter %d:' % (i + 1))
    print(list(model.parameters())[i].size())
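# Optional, not in the original post: sum the element counts to get the total
# number of trainable parameters as a quick sanity check.
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('Total trainable parameters:', total_params)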
# 13. Train the model
sequence_dim = 28     # sequence length (one image row per time step)
loss_list = []        # record the training loss
accuracy_list = []    # record the validation accuracy
iteration_list = []   # record the iteration count
iter = 0
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        model.train()  # switch to training mode
        # Reshape one batch into the RNN input shape (batch, seq_len, input_dim)
        images = images.view(-1, sequence_dim, input_dim).requires_grad_().to(device)
        labels = labels.to(device)
        # Clear the gradients (otherwise they keep accumulating)
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Increment the iteration counter
        iter += 1
        # Validate the model every 500 iterations
        if iter % 500 == 0:
            model.eval()  # switch to evaluation mode
            # Accumulators for the validation accuracy
            correct = 0.0
            total = 0.0
            # Iterate over the test set, get the data, and predict
            for images, labels in test_loader:
                images = images.view(-1, sequence_dim, input_dim).to(device)
                # Model prediction
                outputs = model(images)
                # Index of the largest score = predicted class
                predict = torch.max(outputs.data, 1)[1]
                # Count the number of test samples
                total += labels.size(0)
                # Count the number of correct predictions
                if torch.cuda.is_available():
                    correct += (predict.cpu() == labels.cpu()).sum().item()
                else:
                    correct += (predict == labels).sum().item()
            # Compute the accuracy as a percentage
            accuracy = 100 * correct / total
            # Record accuracy, loss, and iteration count
            loss_list.append(loss.item())
            accuracy_list.append(accuracy)
            iteration_list.append(iter)
            # Print progress
            print("Iteration: {}, Loss: {:.4f}, Accuracy: {:.2f}%".format(iter, loss.item(), accuracy))
# Visualize the loss
plt.plot(iteration_list, loss_list)
plt.xlabel('Number of Iterations')
plt.ylabel('Loss')
plt.title('RNN')
plt.show()
# Visualize the accuracy
plt.plot(iteration_list, accuracy_list, color = 'r')
plt.xlabel('Number of Iterations')
plt.ylabel('Accuracy')
plt.title('RNN')
plt.savefig('RNN_mnist.png')
plt.show()
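# A minimal usage sketch, not part of the original post: save the trained
# weights and classify a single test image. The file name 'rnn_mnist.pth'
# is an arbitrary choice for illustration.
torch.save(model.state_dict(), 'rnn_mnist.pth')
model.eval()
with torch.no_grad():
    sample_img, sample_label = testsets[0]                         # (1, 28, 28) tensor and its integer label
    sample_seq = sample_img.view(-1, sequence_dim, input_dim).to(device)
    pred = torch.max(model(sample_seq), 1)[1].item()
print('Predicted digit: {}, true digit: {}'.format(pred, sample_label))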