当前位置:网站首页>Pytorch model loading and saving, and training based on the saved model

Pytorch model loading and saving, and training based on the saved model

2022-06-12 01:20:00 I SONGFENG water month

There are two approaches:
1. Save only the parameters (the officially recommended way; it uses less storage space). This post explains this approach in detail.
2. Save the entire model, including its structure.

I. Saving only the parameters

1. Saving:
Method 1:

# Write only the model's parameters (its state_dict) to the file at `path`.
torch.save(model.state_dict(), path)

`model` is an instance of the defined model, e.g. `model = resnet()`; `path` is the location where the model is saved, e.g. `path = "./model.pth"`, `path = "./model.pkl"` or `path = "./model.tar"`. Be sure to include a file extension.
Method 2:
If you also want to save training state together with the model parameters (for example the optimizer state and the current epoch), you can bundle everything into a dictionary:

# Bundle the model parameters, the optimizer state and the current epoch
# into a single checkpoint dictionary and write it to `path`.
state = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "epoch": epoch,
}
# Fixed: the original read `torch,save(...)` (comma instead of dot), which
# builds a tuple and calls an undefined `save`, raising NameError.
torch.save(state, path)

2. Loading:
Loading a model saved with Method 1:

# Read the saved parameters from `path` and copy them into the existing
# `model` instance. Fixed: the original misspelled the module as `torrch`.
model.load_state_dict(torch.load(path))

Loading a model saved with Method 2:

# Read the checkpoint dictionary back from `path`.
checkpoint = torch.load(path)
# Restore the model parameters and the optimizer state into existing objects.
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
# Epoch index that was stored alongside the weights when the file was saved.
epoch = checkpoint['epoch']

When a model has been saved with the parameters-only approach, the model you load the parameters into must have the same structure as the one that was saved; otherwise the network structure has to be adjusted (usually only the final output layer). The parameters are loaded into an existing model instance (here, `model`): before running the loading statements above, you must already have defined the same network class `Net` as the original and instantiated it with `model = Net()`.
If you want to save the parameters every epoch, or every n epochs, use a different path each time, e.g. `path = './model' + str(epoch) + '.pth'`, so that each epoch is written to its own file. Keeping only the parameters with the highest accuracy works the same way: just wrap the save statement in an `if` check.

Example: keep only the most recent parameters

#-*- coding:utf-8 -*-

'''Example script showing how to save and load a PyTorch checkpoint.

Trains a small CNN on CIFAR-10 and, after every epoch, saves the model
parameters, the optimizer state and the epoch index to a single file. When
``test_flag`` is True the saved checkpoint is loaded and only evaluated.
'''


import torch as torch
import torchvision as tv
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.transforms import ToPILImage
import torch.backends.cudnn as cudnn
import datetime
import argparse

# Run configuration (read as module-level globals by the code below)
batch_size = 32
epochs = 10
WORKERS = 0   # passed to both DataLoaders as num_workers
test_flag = True  # when True, main() loads the saved checkpoint, runs the test pass and returns
ROOT = '/home/pxt/pytorch/cifar'  # root directory handed to the CIFAR-10 dataset (not MNIST)
log_dir = '/home/pxt/pytorch/logs/cifar_model.pth'  # checkpoint file path (a file, despite the name)

# Load the CIFAR-10 datasets. Images are converted to tensors and each of the
# three channels is normalized with mean 0.5 and std 0.5.
transform = tv.transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])

# Only the training split requests a download; the test split is created
# afterwards from the same ROOT with download=False.
train_data = tv.datasets.CIFAR10(root=ROOT, train=True, download=True, transform=transform)
test_data = tv.datasets.CIFAR10(root=ROOT, train=False, download=False, transform=transform)

# Training batches are shuffled; test batches keep the dataset order.
train_load = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=WORKERS)
test_load = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=WORKERS)


#  Build a model 
class Net(nn.Module):
    """Small CNN classifier: four 3x3 convolutions, two 2x2 max-pools, three linear layers.

    Every convolution uses padding=1, so only the two pooling stages shrink
    the feature map (each halves height and width). ``fc1`` expects
    256 * 8 * 8 input features, which corresponds to 3-channel 32x32 inputs.
    The network returns 10 raw scores (logits) per sample.
    """

    def __init__(self):
        super().__init__()
        # Layer names and creation order are kept stable because they
        # determine the keys (and ordering) of the saved state_dict.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1)
        self.conv4 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(in_features=256 * 8 * 8, out_features=1024)
        self.fc2 = nn.Linear(in_features=1024, out_features=256)
        self.fc3 = nn.Linear(in_features=256, out_features=10)

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        # Stage 1: conv1 -> ReLU -> conv2 -> ReLU -> 2x2 max-pool.
        stage1 = self.pool(F.relu(self.conv2(F.relu(self.conv1(x)))))
        # Stage 2: conv3 -> ReLU -> conv4 -> ReLU -> 2x2 max-pool.
        stage2 = self.pool(F.relu(self.conv4(F.relu(self.conv3(stage1)))))
        # Flatten channels x height x width into one feature vector per sample.
        num_features = stage2.size(1) * stage2.size(2) * stage2.size(3)
        flat = stage2.view(-1, num_features)
        # Classifier head; the last layer has no activation (raw logits).
        hidden = F.relu(self.fc2(F.relu(self.fc1(flat))))
        return self.fc3(hidden)


# Module-level training objects; train(), test() and main() read `criterion`
# and `optimizer` as globals. The model is moved to the GPU with .cuda(),
# so this script needs a CUDA device to run.
model = Net().cuda()
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)


#  model training 
def train(model, train_loader, epoch):
    """Run one training epoch over ``train_loader`` and print the mean batch loss.

    Uses the module-level ``optimizer`` and ``criterion``. Inputs and labels
    are moved to the GPU with ``.cuda()`` before the forward pass.

    Args:
        model: network to train; switched to training mode here.
        train_loader: iterable yielding ``(inputs, labels)`` batches.
        epoch: epoch index, used only in the printed message.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    model.train()
    train_loss = 0.0
    num_batches = 0
    for x, y in train_loader:
        x = x.cuda()
        y = y.cuda()
        optimizer.zero_grad()
        y_hat = model(x)
        loss = criterion(y_hat, y)
        loss.backward()
        optimizer.step()
        # Accumulate a plain Python float. Summing the loss tensor itself
        # would keep every batch's autograd history alive for the whole epoch.
        train_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        # The original crashed here with an obscure UnboundLocalError.
        raise ValueError('train_loader yielded no batches')
    loss_mean = train_loss / num_batches
    print('Train Epoch: {}\t Loss: {:.6f}'.format(epoch, loss_mean))

#  Model test 
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader`` and print average loss and accuracy.

    Uses the module-level ``criterion``. Inputs and labels are moved to the
    GPU with ``.cuda()``. Accuracy is computed over the samples actually
    yielded by ``test_loader`` rather than over a global dataset, so the
    function reports correctly for whatever loader it is given.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        test_loader: iterable yielding ``(inputs, labels)`` batches.

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    num_batches = 0
    num_samples = 0
    # No gradients are needed for evaluation, so no optimizer call is made here.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda()
            y = y.cuda()
            y_hat = model(x)
            test_loss += criterion(y_hat, y).item()
            # Predicted class = index of the largest logit in each row.
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            num_batches += 1
            num_samples += y.size(0)
    if num_batches == 0:
        # The original crashed here with an obscure UnboundLocalError.
        raise ValueError('test_loader yielded no batches')
    test_loss /= num_batches
    print('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, num_samples, 100. * correct / num_samples))


def main():
    """Evaluate the saved checkpoint, or train from scratch and checkpoint every epoch.

    Relies on the module-level ``model``, ``optimizer``, ``train_load``,
    ``test_load``, ``epochs``, ``test_flag`` and ``log_dir``.

    When ``test_flag`` is True, the checkpoint at ``log_dir`` is loaded, the
    test pass is run once and the function returns. Otherwise the model is
    trained for ``epochs`` epochs and the checkpoint file is overwritten after
    each one, so it always holds the most recent state.
    """
    if test_flag:
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # The stored epoch is deliberately NOT assigned to ``epochs``: doing so
        # made ``epochs`` a local variable of main(), and the training loop
        # below then failed with UnboundLocalError whenever test_flag was False.
        test(model, test_load)
        return

    for epoch in range(epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # Save model parameters, optimizer state and the finished epoch index.
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        torch.save(state, log_dir)

# Run main() only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()

3. Resuming training from a saved model:

def main():
    """Evaluate the saved checkpoint, or resume/start training and checkpoint every epoch.

    Relies on the module-level ``model``, ``optimizer``, ``train_load``,
    ``test_load``, ``epochs``, ``test_flag`` and ``log_dir``.

    When ``test_flag`` is True, the checkpoint at ``log_dir`` is loaded, the
    test pass is run once and the function returns. Otherwise, if a checkpoint
    file exists, training resumes at the epoch after the saved one; if not,
    training starts at epoch 0. The checkpoint file is overwritten after each
    epoch.
    """
    # Imported locally because the surrounding script does not import ``os``.
    import os

    if test_flag:
        # Load the saved checkpoint and only evaluate it on the test set.
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        test(model, test_load)
        return

    # If a saved checkpoint exists, load it and continue training from it.
    if os.path.exists(log_dir):
        checkpoint = torch.load(log_dir)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        saved_epoch = checkpoint['epoch']
        # The saved epoch has already been trained, so resume with the next one.
        start_epoch = saved_epoch + 1
        print(' load  epoch {}  success !'.format(saved_epoch))
    else:
        # Fixed off-by-one: the original looped from start_epoch + 1 with
        # start_epoch = 0, which silently skipped epoch 0 on a fresh run.
        start_epoch = 0
        print(' No saved model , Will start training from scratch !')

    for epoch in range(start_epoch, epochs):
        train(model, train_load, epoch)
        test(model, test_load)
        # Save model parameters, optimizer state and the finished epoch index.
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'epoch': epoch,
        }
        torch.save(state, log_dir)

II. Saving the entire model

1. Saving:

# Serialize the whole model object (not just its state_dict) to `path`.
torch.save(model, path)

2. Loading:

# Load the whole model object back from `path`.
# NOTE(review): this unpickles the file, so only load files you trust, and the
# model's class definition presumably has to be importable at load time.
# Recent PyTorch releases default torch.load to weights_only=True, in which
# case loading a full model needs weights_only=False; confirm for your version.
model = torch.load(path)

Reference links to this blog post :[https://www.jianshu.com/p/1cd6333128a1]

原网站

版权声明
本文为[I SONGFENG water month]所创,转载请带上原文链接,感谢
https://yzsam.com/2022/03/202203011409433914.html