
Saving and loading network models

2022-07-08 01:01:00 booze-J

Saving a network model

Method 1

The example code is as follows:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d

# Build the network model (no pretrained weights)
vgg16 = torchvision.models.vgg16(pretrained=False)

# Save method 1: saves the model structure + the model parameters
torch.save(vgg16, "vgg16_method1.pth")

torch.save(vgg16, "vgg16_method1.pth") saves both the structure of the network model and its parameters (model structure + model parameters).

Method 2

The example code is as follows:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d

# Build the network model (no pretrained weights)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Save method 2: saves only the model parameters, as a Python dictionary (officially recommended)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")

torch.save(vgg16.state_dict(), "vgg16_method2.pth") saves only the parameters of the vgg16 network, in the form of a Python dictionary (this is the officially recommended approach). To use the file later, you first build the network model and then load the parameters into it.
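To make "dictionary form" concrete, here is a minimal sketch of what state_dict() returns. The small nn.Sequential model named tiny is a hypothetical stand-in for vgg16, chosen only so the example runs quickly; it is not part of the original article.

```python
import torch
from torch import nn

# Hypothetical tiny model, used in place of vgg16 to keep the example fast.
tiny = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# state_dict() maps each parameter name to its tensor.
state = tiny.state_dict()

# Print every entry: the key is the parameter name, the value is a tensor.
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
```

Only layers with learnable parameters appear in the dictionary, so the ReLU layer contributes no entries. The network structure itself is not stored, which is why the model has to be rebuilt in code before the parameters can be loaded.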

Loading a network model

Method 1

The example code is as follows:

import torchvision
import torch

# Load method 1 -> for files saved with save method 1: loads model + parameters
model = torch.load("vgg16_method1.pth")
print("model", model)

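This assumes the model was saved with save method 1. In that case model = torch.load("vgg16_method1.pth") loads the model and its parameters directly. In recent PyTorch releases (2.6 and later), torch.load defaults to weights_only=True, so loading a whole pickled model this way also requires passing weights_only=False.

Method 2

The example code is as follows: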

import torchvision
import torch
# Build the network model
vgg16 = torchvision.models.vgg16(pretrained=False)
# Load method 2 -> for files saved with save method 2: loads only the parameters
model = torch.load("vgg16_method2.pth")
print("model:\n", model)
# Load the parameters into the model
vgg16.load_state_dict(model)

This assumes the model was saved with save method 2. model = torch.load("vgg16_method2.pth") loads only the model parameters, so you also need to build the network with vgg16 = torchvision.models.vgg16(pretrained=False) and then load the parameters into it with vgg16.load_state_dict(model).
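The full save and load round trip of method 2 can be sketched as follows. This is an illustration under assumptions, not the article's own code: build_model and the file name tiny_method2.pth are hypothetical, and a small nn.Sequential network stands in for vgg16 so the example runs quickly.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    # Hypothetical stand-in for torchvision.models.vgg16(pretrained=False).
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


source = build_model()

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "tiny_method2.pth")

    # Save method 2: only the parameter dictionary is written to disk.
    torch.save(source.state_dict(), path)

    # Load the dictionary back; map_location="cpu" keeps the tensors on the CPU.
    state = torch.load(path, map_location="cpu")

# The structure is rebuilt in code, then the parameters are loaded into it.
restored = build_model()
result = restored.load_state_dict(state)

# Both models should now produce the same output for the same input.
x = torch.randn(5, 4)
with torch.no_grad():
    same_output = torch.allclose(source(x), restored(x))
print(same_output)
```

load_state_dict returns an object listing missing_keys and unexpected_keys; with the default strict=True it raises an error if the dictionary does not match the model, so a silent mismatch is unlikely.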

A pitfall when saving and loading network models

Write a simple network of your own, then save it with method 1.
Sample code:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d

# Build the neural network
class Booze(nn.Module):

    # Initialize the parent class nn.Module
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

    # Override the forward function
    def forward(self, x):
        output = self.maxpool1(x)
        return output


obj = Booze()
# Save the network model (save method 1)
torch.save(obj, "obj_method1.pth")

Run the code. Once the network model has been saved successfully, try loading it from a separate script:

import torchvision
import torch
model = torch.load("obj_method1.pth")
print("model",model)

Result:
Loading the model directly like this raises an error (typically an AttributeError reporting that the attribute Booze cannot be found), because the loading script does not define the Booze class. How can this be solved?

import torchvision
import torch
from torch import nn
from torch.nn import MaxPool2d

# Build the neural network
class Booze(nn.Module):

    # Initialize the parent class nn.Module
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

    # Override the forward function
    def forward(self, x):
        output = self.maxpool1(x)
        return output

# Pitfall 1: the Booze class must be defined (or imported) before calling torch.load
model = torch.load("obj_method1.pth")
print("model", model)

Add the code that defines the neural network to the loading script and run it again; this time no error is reported.
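The same requirement applies to method 2, where it is more visible because the model object is constructed by hand before the parameters are loaded. Below is a minimal sketch under assumptions: TinyNet is a hypothetical network (Booze itself has no learnable parameters, so a convolution layer is added here to give the dictionary some content), and an in-memory buffer is used in place of a .pth file.

```python
import io

import torch
from torch import nn


# Hypothetical custom network; its definition must be available when loading.
class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.maxpool1 = nn.MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, x):
        return self.maxpool1(self.conv1(x))


# Save only the parameters (save method 2) into an in-memory buffer.
buffer = io.BytesIO()
torch.save(TinyNet().state_dict(), buffer)
buffer.seek(0)

# Loading: build the model from the class definition, then load the parameters.
model = TinyNet()
model.load_state_dict(torch.load(buffer))

# Run a forward pass to confirm the restored model is usable.
out = model(torch.randn(1, 3, 5, 5))
print(out.shape)
```

Either way, the saved file never contains the source code of the class. Keeping the network definition in its own module and importing it in both the saving and the loading script is one way to avoid copying the class by hand.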

Original site

Copyright notice
This article was written by [booze-J]. Please include the original link when reposting. Thanks.
https://yzsam.com/2022/189/202207072310362398.html