PyTorch LSTM implementation process (with visualization)
2022-07-06 10:25:00 【How about a song without trace】
```python
# Model 1: PyTorch LSTM implementation process
# Load the dataset
# Make the dataset iterable (read one batch at a time)
# Create the model class
# Instantiate the model class
# Instantiate the loss class
# Train the model

# 1. Load the dataset: import the required libraries
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
import matplotlib.pyplot as plt
```
```python
# 2. Download the datasets
trainsets = datasets.MNIST(root='./data2', train=True, download=True, transform=transforms.ToTensor())
testsets = datasets.MNIST(root='./data2', train=False, transform=transforms.ToTensor())
class_names = trainsets.classes  # View the category labels
print(class_names)
# 3. Inspect the dataset size and shape
print(trainsets.data.shape)     # torch.Size([60000, 28, 28])
print(trainsets.targets.shape)  # torch.Size([60000])
```
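`transforms.ToTensor()` turns each 28x28 uint8 image (values 0-255) into a float tensor of shape (1, 28, 28) with values in [0.0, 1.0]. The NumPy sketch below imitates that conversion for one grayscale image so the resulting shape and range are easy to see. `to_tensor_like` is a hypothetical name, not a torchvision function, and the real transform also handles PIL images and multi-channel input.

```python
import numpy as np

def to_tensor_like(img):
    """Hypothetical helper mimicking transforms.ToTensor for a 2-D uint8
    grayscale image: scale to [0, 1] and add a leading channel axis."""
    img = np.asarray(img)
    assert img.dtype == np.uint8 and img.ndim == 2
    scaled = img.astype(np.float32) / 255.0  # [0, 255] -> [0.0, 1.0]
    return scaled[np.newaxis, :, :]          # (H, W) -> (1, H, W)

# A fake 28x28 "digit": black background with one white pixel
fake = np.zeros((28, 28), dtype=np.uint8)
fake[14, 14] = 255
t = to_tensor_like(fake)
print(t.shape, t.min(), t.max())  # (1, 28, 28) 0.0 1.0
```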
```python
# 4. Define the hyperparameters
BATCH_SIZE = 32  # Number of samples read in each batch
EPOCHS = 10      # Ten rounds (epochs) of training
# 5. Create iterable DataLoaders that read the data one batch at a time
train_loader = torch.utils.data.DataLoader(dataset=trainsets, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testsets, batch_size=BATCH_SIZE, shuffle=True)
# View one batch of data
images, labels = next(iter(test_loader))
print(images.shape)  # torch.Size([32, 1, 28, 28])
```
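With a batch size of 32, the loaders split the 60,000 training images and the 10,000 test images into batches. `len(train_loader)` reports the same count. The plain-Python arithmetic below shows how many optimizer steps ten epochs take and how often the every-500-iterations evaluation in the training loop fires. It assumes the DataLoader default `drop_last=False`, and the helper name is hypothetical.

```python
import math

def batches_per_epoch(n_samples, batch_size, drop_last=False):
    # DataLoader yields a final partial batch unless drop_last=True
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)

train_batches = batches_per_epoch(60000, 32)  # 1875
test_batches = batches_per_epoch(10000, 32)   # 313 (the last batch holds 16 images)
total_iters = train_batches * 10              # 18750 optimizer steps for 10 epochs
n_evals = total_iters // 500                  # 37 evaluations, one every 500 iterations
print(train_batches, test_batches, total_iters, n_evals)
```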
```python
# 6. Define a function to display a batch of images
def imshow(inp, title=None):
    # make_grid returns a (C, H, W) tensor; matplotlib expects (H, W, C)
    inp = inp.numpy().transpose((1, 2, 0))
    # The data only went through ToTensor (no Normalize), so pixel values are
    # already in [0, 1]; un-normalizing with ImageNet mean/std would distort them.
    inp = np.clip(inp, 0, 1)  # Clip values to the [0, 1] range
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)  # Pause briefly so the figure is drawn

# Display the batch as a grid
out = torchvision.utils.make_grid(images)
imshow(out)
```
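`torchvision.utils.make_grid` tiles the 32 images of the batch into one picture. It uses 8 images per row by default with 2 pixels of padding, and repeats a single channel to produce a 3-channel image. The sketch below is a simplified NumPy re-implementation of just the tiling arithmetic, for a (B, H, W) grayscale batch. `tile_grid` is a hypothetical name, not torchvision's code.

```python
import numpy as np

def tile_grid(batch, nrow=8, padding=2, pad_value=0.0):
    """Simplified, hypothetical version of make_grid's layout for a
    (B, H, W) grayscale batch: nrow images per row, padding between them."""
    b, h, w = batch.shape
    ncols = min(nrow, b)
    nrows = int(np.ceil(b / ncols))
    cell_h, cell_w = h + padding, w + padding
    grid = np.full((nrows * cell_h + padding, ncols * cell_w + padding),
                   pad_value, dtype=batch.dtype)
    for k in range(b):
        r, c = divmod(k, ncols)                 # row and column of image k
        top, left = r * cell_h + padding, c * cell_w + padding
        grid[top:top + h, left:left + w] = batch[k]
    return grid

batch = np.random.rand(32, 28, 28).astype(np.float32)
grid = tile_grid(batch)
print(grid.shape)  # (122, 242): 4 rows x 8 columns of 28x28 images, 2-pixel padding
```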
```python
# 7. Define the LSTM model
class LSTM_Model(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(LSTM_Model, self).__init__()  # Initialize the parent nn.Module
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # Build the LSTM; batch_first=True means inputs are (batch, seq_len, input_dim)
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        # Fully connected layer mapping the last hidden state to class scores
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        # Initialize the hidden state with zeros: (layer_dim, batch_size, hidden_dim)
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        # Initialize the cell state with zeros (same shape)
        c0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device)
        # out: (batch_size, seq_len, hidden_dim), the hidden state at every time step
        out, (hn, cn) = self.lstm(x, (h0, c0))
        # Only the hidden state at the last time step is used for classification
        out = self.fc(out[:, -1, :])
        return out
```
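At each time step, `nn.LSTM` combines the current input row with the previous hidden state through four gates. The NumPy sketch below writes out one step using the gate equations and weight layout from the PyTorch `nn.LSTM` documentation. In that layout, the input, forget, cell and output gate blocks are stacked in that order along the first weight dimension. The sketch then runs the step over the 28 rows of an image. The random weights only stand in for learned ones, and the function names are hypothetical.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_cell_step(x, h, c, W_ih, W_hh, b_ih, b_hh):
    """One LSTM time step with PyTorch's gate order (i, f, g, o).
    x: (batch, input_dim); h, c: (batch, hidden_dim)
    W_ih: (4*hidden, input_dim); W_hh: (4*hidden, hidden); biases: (4*hidden,)"""
    gates = x @ W_ih.T + b_ih + h @ W_hh.T + b_hh  # (batch, 4*hidden)
    i, f, g, o = np.split(gates, 4, axis=1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)   # input, forget, output gates
    g = np.tanh(g)                                 # candidate cell values
    c_next = f * c + i * g                         # update the cell state
    h_next = o * np.tanh(c_next)                   # new hidden state
    return h_next, c_next

rng = np.random.default_rng(0)
batch, input_dim, hidden_dim, seq_len = 2, 28, 100, 28
W_ih = rng.normal(scale=0.1, size=(4 * hidden_dim, input_dim))
W_hh = rng.normal(scale=0.1, size=(4 * hidden_dim, hidden_dim))
b_ih = np.zeros(4 * hidden_dim)
b_hh = np.zeros(4 * hidden_dim)

x_seq = rng.random((batch, seq_len, input_dim))  # like (batch, 28 rows, 28 pixels)
h = np.zeros((batch, hidden_dim))                # h0 = zeros, as in forward()
c = np.zeros((batch, hidden_dim))                # c0 = zeros
for t in range(seq_len):                         # feed one image row per time step
    h, c = lstm_cell_step(x_seq[:, t, :], h, c, W_ih, W_hh, b_ih, b_hh)
print(h.shape)  # (2, 100): the last hidden state, what out[:, -1, :] holds for 1 layer
```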
```python
# 8. Initialize the model
input_dim = 28   # Input dimension: each time step is one 28-pixel image row
hidden_dim = 100 # Hidden dimension
layer_dim = 1    # Number of stacked LSTM layers
output_dim = 10  # Output dimension: 10 digit classes
# Use a GPU if one is available
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Instantiate the model, pass in the parameters, and move it to the device
model = LSTM_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)
```
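Before training, a quick shape check can catch device or reshape mistakes early. The sketch below uses only standard PyTorch. So that it runs on its own, it rebuilds an equivalent model under the hypothetical name `TinyLSTMClassifier`. It then reshapes a random tensor shaped like one MNIST batch into 28 time steps of 28 pixels and confirms the model returns one score per class. It relies on `nn.LSTM` defaulting the initial hidden and cell states to zeros when none are passed, which matches the explicit `h0`/`c0` in `forward()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical compact stand-in for LSTM_Model (same layers, same output)
class TinyLSTMClassifier(nn.Module):
    def __init__(self, input_dim=28, hidden_dim=100, layer_dim=1, output_dim=10):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, layer_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        out, _ = self.lstm(x)          # h0 and c0 default to zeros when omitted
        return self.fc(out[:, -1, :])  # classify from the last time step

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = TinyLSTMClassifier().to(device)

fake_batch = torch.rand(32, 1, 28, 28)       # same shape as one MNIST batch
x = fake_batch.view(-1, 28, 28).to(device)   # (batch, seq_len=28 rows, input_dim=28 pixels)
with torch.no_grad():
    logits = model(x)
print(tuple(x.shape), tuple(logits.shape))   # (32, 28, 28) (32, 10)
```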
```python
# 9. Define the loss function
criterion = nn.CrossEntropyLoss()
# 10. Define the optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
```
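`nn.CrossEntropyLoss` takes the raw scores from the final `Linear` layer (the model applies no softmax) together with integer class labels. The NumPy sketch below shows what it computes under its default mean reduction: a log-softmax followed by the mean negative log-probability of the true class. It ends with the plain SGD update (no momentum, as configured above). The function name is hypothetical.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross-entropy over a batch: log-softmax over classes, then the
    mean negative log-probability of each row's true class."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

# With identical scores for all 10 classes the loss is ln(10)
uniform = np.zeros((4, 10))
print(round(cross_entropy(uniform, np.array([0, 3, 7, 9])), 4))  # 2.3026

# A confident, correct prediction gives a loss close to 0
confident = np.full((1, 10), -10.0)
confident[0, 3] = 10.0
print(cross_entropy(confident, np.array([3])) < 1e-6)  # True

# One plain SGD step: parameter <- parameter - lr * gradient
lr = 0.01
w, grad = 0.5, 2.0
w = w - lr * grad
print(round(w, 4))  # 0.48
```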
```python
# 11. Collect the model parameters
params = list(model.parameters())
length = len(params)
# 12. Print the size of each parameter tensor
for i in range(length):
    print('Parameter: %d' % (i + 1))
    print(params[i].size())
```
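The loop should print six tensors. Their shapes follow from two layouts. `nn.LSTM` stacks its four gate weights along the first dimension, and `nn.Linear` stores its weight as (out_features, in_features). The plain-Python sketch below reproduces the expected list for this model's sizes and totals the trainable parameters. The helper names are hypothetical; in PyTorch itself, `sum(p.numel() for p in model.parameters())` gives the same total.

```python
def lstm_classifier_param_shapes(input_dim, hidden_dim, output_dim):
    """Parameter shapes, in registration order, for a 1-layer nn.LSTM
    followed by nn.Linear (bias=True everywhere, the PyTorch default)."""
    return [
        (4 * hidden_dim, input_dim),   # lstm.weight_ih_l0
        (4 * hidden_dim, hidden_dim),  # lstm.weight_hh_l0
        (4 * hidden_dim,),             # lstm.bias_ih_l0
        (4 * hidden_dim,),             # lstm.bias_hh_l0
        (output_dim, hidden_dim),      # fc.weight
        (output_dim,),                 # fc.bias
    ]

def numel(shape):
    n = 1
    for d in shape:
        n *= d
    return n

shapes = lstm_classifier_param_shapes(28, 100, 10)
for k, s in enumerate(shapes, start=1):
    print('Parameter: %d' % k, s)
total = sum(numel(s) for s in shapes)
print('Total:', total)  # Total: 53010
```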
```python
# 13. Train the model
sequence_dim = 28    # Sequence length: 28 rows per image
loss_list = []       # Saved loss values
accuracy_list = []   # Saved accuracy values
iteration_list = []  # Saved iteration counts
n_iter = 0           # Iteration counter (named to avoid shadowing the built-in iter)
for epoch in range(EPOCHS):
    for i, (images, labels) in enumerate(train_loader):
        model.train()  # Switch to training mode
        # Reshape one batch from (batch, 1, 28, 28) to the LSTM input (batch, 28, 28)
        images = images.view(-1, sequence_dim, input_dim).to(device)
        labels = labels.to(device)
        # Clear the gradients (otherwise they accumulate across iterations)
        optimizer.zero_grad()
        # Forward propagation
        outputs = model(images)
        # Compute the loss
        loss = criterion(outputs, labels)
        # Backpropagation
        loss.backward()
        # Update the parameters
        optimizer.step()
        n_iter += 1

        # Evaluate the model on the test set every 500 iterations
        if n_iter % 500 == 0:
            model.eval()  # Switch to evaluation mode
            correct = 0
            total = 0
            with torch.no_grad():  # No gradients are needed for evaluation
                # Iterate over the test set, get the data, and predict
                for images, labels in test_loader:
                    images = images.view(-1, sequence_dim, input_dim).to(device)
                    labels = labels.to(device)
                    outputs = model(images)
                    # Index of the largest score in each row = predicted class
                    predict = torch.max(outputs, 1)[1]
                    # Count the test samples seen
                    total += labels.size(0)
                    # Count the correct predictions
                    correct += (predict == labels).sum().item()
            # Accuracy as a percentage
            accuracy = 100 * correct / total
            # Save the loss, accuracy, and iteration
            loss_list.append(loss.item())
            accuracy_list.append(accuracy)
            iteration_list.append(n_iter)
            # Print progress
            print("Epoch: {}, Iteration: {}, Loss: {}, Accuracy: {}".format(epoch, n_iter, loss.item(), accuracy))
```
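The evaluation step reduces each row of scores to a predicted class and counts how many match the labels. Reporting the result as a percentage means multiplying the fraction `correct / total` by 100 once. The NumPy sketch below mirrors that logic on a made-up toy batch, where 3 correct out of 4 should give 75%. The helper name is hypothetical.

```python
import numpy as np

def accuracy_percent(logits, labels):
    """Percentage of rows whose highest score is on the true label,
    mirroring torch.max(outputs, 1)[1] followed by a comparison with labels."""
    predict = np.argmax(logits, axis=1)  # index of the largest score per row
    correct = int((predict == labels).sum())
    return 100.0 * correct / len(labels)

logits = np.array([
    [0.1, 2.0, 0.3],  # predicts class 1
    [1.5, 0.2, 0.1],  # predicts class 0
    [0.0, 0.1, 3.0],  # predicts class 2
    [2.2, 0.5, 0.4],  # predicts class 0
])
labels = np.array([1, 0, 2, 1])
print(accuracy_percent(logits, labels))  # 75.0
```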
```python
# Visualize the loss
plt.plot(iteration_list, loss_list)
plt.xlabel('Number of iterations')
plt.ylabel('Loss')
plt.title('LSTM')
plt.show()
# Visualize the accuracy
plt.plot(iteration_list, accuracy_list, color='r')
plt.xlabel('Number of iterations')
plt.ylabel('Accuracy')
plt.title('LSTM')
plt.savefig('LSTM_accuracy.png')  # Save before show(), after which the figure may be closed
plt.show()
```
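The script calls `plt.show()` for each curve. Depending on the backend, this opens a window and may block until the window is closed, and a machine without a display has nowhere to show it. The sketch below uses only standard matplotlib. It draws each curve on its own figure with the non-interactive Agg backend and writes it to a file instead. `save_curve`, the output file name and the plotted numbers are hypothetical placeholders, not training results.

```python
import os
import tempfile

import matplotlib
matplotlib.use("Agg")  # non-interactive backend: render to files, no window
import matplotlib.pyplot as plt

def save_curve(xs, ys, ylabel, path, color=None):
    """Plot one curve on its own figure and save it (hypothetical helper)."""
    fig, ax = plt.subplots()
    kwargs = {} if color is None else {"color": color}
    ax.plot(xs, ys, **kwargs)
    ax.set_xlabel('Number of iterations')
    ax.set_ylabel(ylabel)
    ax.set_title('LSTM')
    fig.savefig(path)
    plt.close(fig)  # free the figure so repeated calls do not pile up

# Placeholder values; in the script these would be iteration_list and
# loss_list (plain Python floats, e.g. from loss.item())
iterations = [500, 1000, 1500]
losses = [3.0, 2.0, 1.0]
out_path = os.path.join(tempfile.gettempdir(), 'LSTM_loss_example.png')
save_curve(iterations, losses, 'Loss', out_path)
print(os.path.exists(out_path))  # True
```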


