当前位置:网站首页>PyTorch RNN hands-on example: MNIST handwritten digit recognition

PyTorch RNN hands-on example: MNIST handwritten digit recognition

2022-07-06 10:25:00 How about a song without trace

 Code on GitHub

#  Model 1: PyTorch RNN implementation steps
#  - Load the dataset
#  - Make the dataset iterable (read one batch at a time)
#  - Create the model class
#  - Instantiate the model
#  - Instantiate the loss function
#  - Train the model
# 1. Load the dataset
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# 2. Download the datasets. Only ToTensor is applied, so pixel values are
# scaled to [0, 1] and are NOT mean/std normalized.
trainsets = datasets.MNIST(root='./data2', train=True, download=True, transform=transforms.ToTensor())
testsets = datasets.MNIST(root='./data2', train=False, download=True, transform=transforms.ToTensor())

class_names = trainsets.classes  # Human-readable class labels
print(class_names)

# 3. Inspect the raw dataset shapes
print(trainsets.data.shape)
print(trainsets.targets.shape)

# 4. Hyperparameters
BATCH_SIZE = 32  # Number of samples per batch
BASH_SIZE = BATCH_SIZE  # Misspelled name kept as an alias for any existing references
EPOCHS = 10  # Number of passes over the training set

# 5. Wrap the datasets so they yield one batch at a time.
# Training data is reshuffled every epoch; test data stays in a fixed order
# because evaluation does not need shuffling and this keeps it reproducible.
train_loader = torch.utils.data.DataLoader(dataset=trainsets, batch_size=BATCH_SIZE, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=testsets, batch_size=BATCH_SIZE, shuffle=False)

# Fetch one batch to check the tensor shape
images, labels = next(iter(test_loader))
print(images.shape)
# 6. Helper to display a batch of images
def imshow(inp, title=None, mean=None, std=None):
    """Display a channel-first image tensor with matplotlib.

    Args:
        inp: image tensor laid out as (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``.
        title: optional title for the plot.
        mean: optional per-channel mean to undo a ``transforms.Normalize``.
        std: optional per-channel standard deviation to undo a
            ``transforms.Normalize``.

    Leave ``mean``/``std`` as ``None`` for data that went only through
    ``ToTensor`` (as the MNIST datasets in this script do); such data is
    already in [0, 1]. Un-normalizing it with other statistics distorts the
    displayed image.
    """
    # Channel-first tensor -> channel-last NumPy array, as plt.imshow expects.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))
    if mean is not None or std is not None:
        scale = 1.0 if std is None else np.asarray(std)
        shift = 0.0 if mean is None else np.asarray(mean)
        img = scale * img + shift
    # Clamp into the displayable [0, 1] range.
    img = np.clip(img, 0, 1)
    plt.imshow(img)
    if title is not None:
        plt.title(title)
    # Short pause so the figure gets drawn in interactive sessions.
    plt.pause(0.001)


# Tile the batch into a single grid image and display it
out = torchvision.utils.make_grid(images)

imshow(out)
# 7. Define the RNN model
class RNN_Model(nn.Module):
    """Many-to-one RNN classifier: runs an RNN over a sequence and classifies
    from the output at the last time step.

    Args:
        input_dim: number of features per time step.
        hidden_dim: size of the RNN hidden state.
        layer_dim: number of stacked RNN layers.
        output_dim: number of classes; ``forward`` returns raw logits.
    """

    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.layer_dim = layer_dim
        # batch_first=True: inputs are (batch, seq_len, input_dim).
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
        # Fully connected layer mapping the final hidden output to class logits.
        self.fc = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return logits of shape (batch, output_dim) for x of shape (batch, seq_len, input_dim)."""
        # Zero initial hidden state on the same device/dtype as the input, so
        # the model does not rely on a module-level ``device`` variable. It
        # needs no gradient because it is a constant.
        h0 = torch.zeros(self.layer_dim, x.size(0), self.hidden_dim, device=x.device, dtype=x.dtype)
        out, _ = self.rnn(x, h0)
        # Classify from the output at the last time step.
        return self.fc(out[:, -1, :])
# 8. Initialize the model
input_dim = 28  # Features per time step (pixels in one image row)
hidden_dim = 100  # Size of the RNN hidden state
layer_dim = 2  # Number of stacked RNN layers
output_dim = 10  # Number of output classes
# Use the first GPU when available, otherwise the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Instantiate the model and move its parameters to the selected device so
# they live on the same device as the inputs in the training loop.
model = RNN_Model(input_dim, hidden_dim, layer_dim, output_dim).to(device)
# 9. Loss function (takes raw logits and integer class labels)
criterion = nn.CrossEntropyLoss()
# 10. Optimizer
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

# 11. Number of parameter tensors in the model
length = len(list(model.parameters()))
# 12. Print the shape of every parameter tensor
for i, param in enumerate(model.parameters(), start=1):
    print(' Parameters : %d' % i)
    print(param.size())


# 13. Model training
sequence_dim = 28  # Sequence length: each image is fed to the RNN row by row
loss_list = []  # Training loss (float) at each evaluation point
accuracy_list = []  # Test accuracy in percent at each evaluation point
iteration_list = []  # Iteration count at each evaluation point
EVAL_EVERY = 500  # Evaluate on the test set every this many iterations


def evaluate_accuracy(model, loader):
    """Return ``model``'s classification accuracy on ``loader`` as a percentage.

    Runs in eval mode without building autograd graphs. Images are reshaped
    to (batch, sequence_dim, input_dim) and, together with the labels, moved
    to the module-level ``device``. Returns 0.0 for an empty loader.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images = images.view(-1, sequence_dim, input_dim).to(device)
            labels = labels.to(device)
            outputs = model(images)
            # The index of the largest logit is the predicted class.
            predict = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predict == labels).sum().item()
    return 100.0 * correct / total if total else 0.0


iteration = 0  # Global step counter (named to avoid shadowing the builtin iter)
for epoch in range(EPOCHS):
    for images, labels in train_loader:
        # Re-enter training mode every step, since evaluation switches to eval mode.
        model.train()
        # Reshape each batch to the RNN input layout (batch, seq_len, input_dim).
        images = images.view(-1, sequence_dim, input_dim).to(device)
        labels = labels.to(device)
        # Clear gradients left over from the previous step (they accumulate otherwise).
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        iteration += 1
        # Periodic evaluation on the test set
        if iteration % EVAL_EVERY == 0:
            accuracy = evaluate_accuracy(model, test_loader)
            # Store plain Python floats so the lists can be plotted directly.
            loss_list.append(loss.item())
            accuracy_list.append(accuracy)
            iteration_list.append(iteration)
            print("epoch : {}, iteration : {}, Loss : {}, Accuracy : {}".format(
                epoch, iteration, loss.item(), accuracy))

# Visualize the loss. Each plot gets its own figure: plt.show() may leave the
# current figure open (e.g. on a non-interactive backend), and the second plot
# would otherwise be drawn on top of the first and saved together with it.
plt.figure()
# float() accepts plain numbers as well as one-element tensors (CPU or GPU).
plt.plot(iteration_list, [float(v) for v in loss_list])
plt.xlabel('Number of Iteration')
plt.ylabel('Loss')
plt.title('RNN')
plt.show()

# Visualize the accuracy and save it to disk
plt.figure()
plt.plot(iteration_list, [float(v) for v in accuracy_list], color='r')
plt.xlabel('Number of Iteration')
plt.ylabel('Accuracy')
plt.title('RNN')
plt.savefig('RNN_mnist.png')
plt.show()


 

 

原网站

版权声明
本文为[How about a song without trace]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/187/202207060910587091.html

随机推荐