PyTorch LSTM implementation process (with visualization)
Posted 2022-07-06 10:25:00 by How about a song without trace
# Model 1: PyTorch LSTM implementation process
# Load the dataset
# Make the dataset iterable (read one batch at a time)
# Create the model class
# Instantiate the model class
# Instantiate the loss class
# Train the model
# 1. Load the dataset
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
import matplotlib.pyplot as plt
# 2. Download the datasets (ToTensor() converts each image to a float tensor in [0, 1])
trainsets = datasets.MNIST(root='./data2', train=True, download=True, transform=transforms.ToTensor())
testsets = datasets.MNIST(root='./data2', train=False, download=True, transform=transforms.ToTensor())
class_names = trainsets.classes  # View the class labels
print(class_names)
# 3. Look at the dataset shape
print(trainsets.data.shape)     # torch.Size([60000, 28, 28])
print(trainsets.targets.shape)  # torch.Size([60000])
# 4. Define hyperparameters
BATCH_SIZE = 32  # Number of samples read per batch
EPOCHS = 10      # Train for ten epochs
# 5. Make the datasets iterable, i.e. read the data one batch at a time
train_loader = torch.utils.data.DataLoader(dataset=trainsets, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testsets, batch_size=BATCH_SIZE, shuffle=False)
# Look at one batch of data
images, labels = next(iter(test_loader))
print(images.shape)  # torch.Size([32, 1, 28, 28])
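With a batch size of 32, the loaders split MNIST's 60,000 training and 10,000 test images into batches, keeping a smaller final batch when the size does not divide evenly (the DataLoader default is `drop_last=False`). The plain-Python arithmetic below is a small sketch of how many iterations the training loop runs over 10 epochs and how many times the every-500-iterations evaluation fires; `batches_per_epoch` is an illustrative helper, not a PyTorch function.

```python
import math

def batches_per_epoch(num_samples, batch_size):
    # DataLoader keeps the final partial batch by default (drop_last=False)
    return math.ceil(num_samples / batch_size)

train_batches = batches_per_epoch(60000, 32)  # 60000 / 32 divides evenly
test_batches = batches_per_epoch(10000, 32)   # the last test batch holds only 16 images
total_iterations = train_batches * 10         # 10 epochs; the counter is never reset
evaluations = total_iterations // 500         # evaluation runs whenever the counter % 500 == 0

print(train_batches, test_batches, total_iterations, evaluations)  # 1875 313 18750 37
```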
# 6. Define a function that displays a batch of images
def imshow(inp, title=None):
    # (C, H, W) tensor -> (H, W, C) array, the layout matplotlib expects
    inp = inp.numpy().transpose((1, 2, 0))
    # ToTensor() already scaled pixels to [0, 1] and no Normalize transform was applied,
    # so no de-normalization (e.g. with ImageNet mean/std) is needed here
    inp = np.clip(inp, 0, 1)  # Clip values to the range [0, 1]
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)

# Display the batch as a grid
out = torchvision.utils.make_grid(images)
imshow(out)
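`make_grid` tiles the whole batch into a single image, which is why `imshow` receives one (C, H, W) tensor and transposes it to (H, W, C) for matplotlib. The plain-Python sketch below reproduces the grid-size arithmetic, assuming torchvision's defaults of `nrow=8` and `padding=2` and that single-channel input is repeated to 3 channels; `grid_shape` is a hypothetical helper, not part of torchvision.

```python
import math

def grid_shape(num_images, height, width, nrow=8, padding=2, channels=1):
    # Hypothetical helper mirroring make_grid's layout (assumed defaults nrow=8, padding=2)
    cols = min(nrow, num_images)
    rows = math.ceil(num_images / cols)
    # Each tile carries padding on one side; one extra padding strip closes the grid
    grid_h = rows * (height + padding) + padding
    grid_w = cols * (width + padding) + padding
    # Grayscale (1-channel) input is expanded to 3 channels
    out_channels = 3 if channels == 1 else channels
    return (out_channels, grid_h, grid_w)

chw = grid_shape(32, 28, 28)       # one batch of 32 MNIST digits
hwc = (chw[1], chw[2], chw[0])     # order after transpose((1, 2, 0)) in imshow
print(chw, hwc)                    # (3, 122, 242) (122, 242, 3)
```

A batch of 32 therefore shows up as 4 rows of 8 digits.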
# 7. Define the LSTM model
class LSTM_Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTM_Model, self).__init__()  # Call the parent-class constructor
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # Build the LSTM; batch_first=True means input shape (batch, seq_len, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        # Fully connected layer mapping the last hidden state to class scores
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initialize the hidden state with zeros: (layer_dim, batch_size, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        # Initialize the cell state with zeros, same shape
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        # The zero initial states are constants, so they need no gradient
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # out has shape (batch, seq_len, hidden_dim); keep only the last time step
        out = self.fc(out[:, -1, :])
        return out
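`forward` relies on `nn.LSTM` to run the recurrence over the 28 image rows and then keeps only the output of the last time step. As a rough sketch of what a single-layer LSTM computes, the plain-Python code below applies the standard LSTM cell equations step by step, starting from zero states like `h0`/`c0`. Assumptions: gates are stacked in PyTorch's i, f, g, o order, PyTorch's two bias vectors (`bias_ih` and `bias_hh`) are folded into one `b`, and all function names are hypothetical. It is for illustration only and far slower than `nn.LSTM`.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def matvec(W, v):
    # Multiply a matrix (list of rows) by a vector
    return [sum(w * x for w, x in zip(row, v)) for row in W]

def lstm_cell(x, h, c, W_ih, W_hh, b):
    # One LSTM step; rows of W_ih, W_hh and b are stacked as i, f, g, o (assumed order)
    H = len(h)
    z = [a + r + bb for a, r, bb in zip(matvec(W_ih, x), matvec(W_hh, h), b)]
    i = [sigmoid(v) for v in z[0:H]]              # input gate
    f = [sigmoid(v) for v in z[H:2 * H]]          # forget gate
    g = [math.tanh(v) for v in z[2 * H:3 * H]]    # candidate cell state
    o = [sigmoid(v) for v in z[3 * H:4 * H]]      # output gate
    c_new = [fv * cv + iv * gv for fv, cv, iv, gv in zip(f, c, i, g)]
    h_new = [ov * math.tanh(cv) for ov, cv in zip(o, c_new)]
    return h_new, c_new

def lstm_last_hidden(sequence, W_ih, W_hh, b, hidden_size):
    # Start from zero states and return the final h, which is what
    # out[:, -1, :] selects for a one-layer, unidirectional LSTM
    h = [0.0] * hidden_size
    c = [0.0] * hidden_size
    for x in sequence:
        h, c = lstm_cell(x, h, c, W_ih, W_hh, b)
    return h

# Made-up example: 3 time steps, input_size 2, hidden_size 1
seq = [[0.3, -0.1], [0.5, 0.2], [-0.4, 0.1]]
W_ih = [[0.1, 0.2], [0.0, -0.1], [0.3, 0.1], [0.2, 0.2]]  # shape (4 * hidden, input)
W_hh = [[0.5], [0.1], [-0.2], [0.3]]                      # shape (4 * hidden, hidden)
b = [0.0, 1.0, 0.0, 0.0]
print(lstm_last_hidden(seq, W_ih, W_hh, b, hidden_size=1))
```

In the MNIST model each sequence has 28 steps of 28 pixel values and the hidden size is 100, but the per-step computation is the same.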
# 8. Initialize the model
input_dim = 28    # Input dimension (pixels per image row)
hidden_dim = 100  # Hidden dimension
layer_dim = 1     # Number of stacked LSTM layers
output_dim = 10   # Output dimension (10 digit classes)
# Check whether a GPU is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Instantiate the model, pass in the parameters, and move it to the device
model = LSTM_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)
# 9. Define the loss function
criterion = nn.CrossEntropyLoss()
# 10. Define the optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
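`nn.CrossEntropyLoss` takes the raw scores (logits) produced by `self.fc` and combines log-softmax with negative log-likelihood, which is why the model has no softmax layer. Plain SGD then moves each parameter against its gradient, scaled by `learning_rate`. A minimal plain-Python sketch of both formulas for a single sample; the function names are illustrative, and it ignores the averaging over the batch that `CrossEntropyLoss` applies by default.

```python
import math

def cross_entropy(logits, target):
    # loss = -log(softmax(logits)[target]) = logsumexp(logits) - logits[target]
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]

def sgd_step(params, grads, lr=0.01):
    # Vanilla SGD (no momentum, no weight decay): p <- p - lr * grad
    return [p - lr * g for p, g in zip(params, grads)]

print(cross_entropy([2.0, 0.5, -1.0], target=0))  # small loss: class 0 has the top score
print(sgd_step([1.0, -2.0], [0.5, -0.5]))
```

With 10 equal logits the loss is log(10), about 2.3, which is roughly where training on MNIST starts.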
# 11. Get the model parameters
params = list(model.parameters())
# 12. Print the shape of each parameter tensor
for i, param in enumerate(params, start=1):
    print('Parameter %d:' % i)
    print(param.size())
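For this configuration the loop prints six tensors. Per the PyTorch `nn.LSTM` documentation, layer 0 holds `weight_ih_l0` of shape (4*hidden_size, input_size), `weight_hh_l0` of shape (4*hidden_size, hidden_size) and two bias vectors of length 4*hidden_size, the factor 4 being one block per gate; `nn.Linear` adds an (output_dim, hidden_dim) weight and an output_dim bias. A small sketch that derives these shapes and the total count; the helper is hypothetical and only covers a single-layer, unidirectional LSTM.

```python
import math

def lstm_classifier_param_shapes(input_dim, hidden_dim, output_dim):
    # Shapes in the order model.parameters() yields them for LSTM_Model with layer_dim=1
    gates = 4 * hidden_dim  # input, forget, cell and output gates stacked together
    return [
        (gates, input_dim),        # lstm.weight_ih_l0
        (gates, hidden_dim),       # lstm.weight_hh_l0
        (gates,),                  # lstm.bias_ih_l0
        (gates,),                  # lstm.bias_hh_l0
        (output_dim, hidden_dim),  # fc.weight
        (output_dim,),             # fc.bias
    ]

shapes = lstm_classifier_param_shapes(28, 100, 10)
for k, s in enumerate(shapes, start=1):
    print(k, s)
total = sum(math.prod(s) for s in shapes)
print('total:', total)  # total: 53010
```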
# 13. Train the model
sequence_dim = 28    # Sequence length: each image is read as 28 rows (time steps)
loss_list = []       # Saved loss values
accuracy_list = []   # Saved accuracy values
iteration_list = []  # Iteration numbers at which the values above were recorded
iter_count = 0       # Named so it does not shadow the built-in iter()
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        model.train()  # Switch to training mode
        # Reshape a batch from (batch, 1, 28, 28) to the LSTM input shape (batch, seq_len, input_dim)
        images = images.view(-1, sequence_dim, input_dim).to(device)
        labels = labels.to(device)
        # Clear the gradients (otherwise they keep accumulating)
        optimizer.zero_grad()
        # Forward pass
        outputs = model(images)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backward pass
        loss.backward()
        # Update the parameters
        optimizer.step()
        # Count iterations
        iter_count += 1
        # Evaluate on the test set every 500 iterations
        if iter_count % 500 == 0:
            model.eval()  # Switch to evaluation mode
            correct = 0
            total = 0
            with torch.no_grad():  # No gradients are needed for evaluation
                # Iterate over the test set, get the data, and predict
                for images, labels in test_loader:
                    images = images.view(-1, sequence_dim, input_dim).to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    # Index of the highest score = predicted class
                    predict = torch.max(outputs, 1)[1]
                    # Count the test samples
                    total += labels.size(0)
                    # Count the correct predictions (tensors compared on the same device)
                    correct += (predict == labels).sum().item()
            # Accuracy as a percentage
            accuracy = 100 * correct / total
            # Save loss, accuracy, and iteration number as plain Python numbers
            loss_list.append(loss.item())
            accuracy_list.append(accuracy)
            iteration_list.append(iter_count)
            # Print progress
            print("Epoch: {}, Iteration: {}, Loss: {:.4f}, Accuracy: {:.2f}%".format(epoch, iter_count, loss.item(), accuracy))
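The evaluation step takes the index of the largest output score as the predicted digit and compares it with the label; accuracy is the share of correct predictions, here expressed as a percentage (100 * correct / total). A plain-Python sketch of the same bookkeeping on made-up scores; `argmax` and `accuracy_percent` are illustrative names, not PyTorch functions.

```python
def argmax(scores):
    # Index of the largest score; ties resolve to the first maximum
    best = 0
    for k in range(1, len(scores)):
        if scores[k] > scores[best]:
            best = k
    return best

def accuracy_percent(batch_scores, labels):
    predictions = [argmax(s) for s in batch_scores]
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return 100 * correct / len(labels)

# Made-up scores for 4 samples and 3 classes
scores = [[0.1, 2.3, -0.5], [1.5, 0.2, 0.3], [0.0, 0.1, 0.9], [2.0, 1.0, 3.0]]
labels = [1, 0, 1, 2]
print(accuracy_percent(scores, labels))  # 75.0
```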
# Visualize the loss
plt.plot(iteration_list, loss_list)
plt.xlabel('Number of Iterations')
plt.ylabel('Loss')
plt.title('LSTM')
plt.show()
# Visualize the accuracy
plt.plot(iteration_list, accuracy_list, color='r')
plt.xlabel('Number of Iterations')
plt.ylabel('Accuracy')
plt.title('LSTM')
plt.savefig('LSTM_accuracy.png')
plt.show()


