Saving and loading network models
2022-07-08 01:01:00 【booze-J】
Saving a network model
Method 1
Example code:
import torch
import torchvision
# Build the VGG16 network model (randomly initialized, no pretrained weights)
# On torchvision >= 0.13, weights=None is the non-deprecated equivalent of pretrained=False
vgg16 = torchvision.models.vgg16(pretrained=False)
# Saving method 1: saves the model structure + model parameters
torch.save(vgg16, "vgg16_method1.pth")
torch.save(vgg16, "vgg16_method1.pth") pickles the whole model object. The file therefore contains both the model structure and all of its parameters (model structure + model parameters).
Method 2
Example code:
import torch
import torchvision
# Build the VGG16 network model (randomly initialized, no pretrained weights)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Saving method 2 (officially recommended): save only the model parameters, as a Python dictionary (state_dict)
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
torch.save(vgg16.state_dict(), "vgg16_method2.pth") saves only the parameters of vgg16, as a Python dictionary that maps each parameter name to its tensor. This is the officially recommended approach. The file holds no model structure, so loading it means building the network model first and then loading the parameters into it.
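The sketch below inspects a state_dict to show what method 2 actually writes to disk. A small nn.Sequential stands in for vgg16; this is an assumption made only to keep the output short. For vgg16 itself the keys look like features.0.weight. Each entry maps a "<submodule>.<parameter>" name to a plain tensor. Layers without parameters, such as ReLU, contribute no entries.

```python
import torch
from torch import nn

# Small stand-in for vgg16 (hypothetical, chosen only to keep the output short)
model = nn.Sequential(
    nn.Linear(4, 8),  # index 0: has weight and bias
    nn.ReLU(),        # index 1: no parameters, so no state_dict entries
    nn.Linear(8, 2),  # index 2: has weight and bias
)

state = model.state_dict()
for name, tensor in state.items():
    # Keys are "<submodule index>.<parameter name>"; values are plain tensors
    print(name, tuple(tensor.shape))
# 0.weight (8, 4)
# 0.bias (8,)
# 2.weight (2, 8)
# 2.bias (2,)
```

The file contains only names and tensors, so the code that builds the structure must be available again at load time.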
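Loading a network model
Method 1
Example code:
import torchvision
import torch
# Loading method 1 (for files saved with saving method 1): loads the model structure + parameters
# Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle a whole model;
# passing weights_only=False (accepted since PyTorch 1.13) restores the old behavior for files you trust
model = torch.load("vgg16_method1.pth", weights_only=False)
print("model", model)
If the model was saved with method 1, calling model = torch.load("vgg16_method1.pth", weights_only=False) loads both the model structure and its parameters directly.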
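Below is a minimal round-trip sketch of saving method 1 and loading method 1. It assumes PyTorch 1.13 or later (for the weights_only argument). It uses a small nn.Sequential in place of vgg16, and an in-memory io.BytesIO buffer in place of the .pth file, so no files are left behind. It shows that the loaded object is already a complete model: it produces the same outputs as the original without being rebuilt.

```python
import io

import torch
from torch import nn

# Small stand-in for vgg16 (assumption: any built-in nn.Module behaves the same way)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Saving method 1: pickle the whole module (structure + parameters)
buffer = io.BytesIO()
torch.save(model, buffer)
buffer.seek(0)

# Loading method 1: the result is a ready-to-use nn.Module.
# weights_only=False is needed where torch.load defaults to weights_only=True.
restored = torch.load(buffer, weights_only=False)

x = torch.randn(3, 4)
print(type(restored).__name__)             # Sequential
print(torch.equal(model(x), restored(x)))  # True: same structure and parameters
```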
Method 2
Example code:
import torchvision
import torch
# Build the network model structure first
vgg16 = torchvision.models.vgg16(pretrained=False)
# Loading method 2 (for files saved with saving method 2): loads the parameters only
model = torch.load("vgg16_method2.pth")
print("model:\n", model)
# Copy the loaded parameters into the model
vgg16.load_state_dict(model)
If the model was saved with method 2, model = torch.load("vgg16_method2.pth") loads only the model parameters, as a dictionary rather than a model. You therefore have to build the network model first with vgg16 = torchvision.models.vgg16(pretrained=False), and then load the parameters into it with vgg16.load_state_dict(model).
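The two loading steps of method 2 can be sketched as follows. The example assumes PyTorch 1.13 or later. build_model is a hypothetical factory that stands in for torchvision.models.vgg16(pretrained=False), and an in-memory buffer replaces the .pth file. torch.load returns a plain dictionary of tensors, which is safe to load with weights_only=True. load_state_dict then copies those tensors into a freshly built model and reports any missing or unexpected keys.

```python
import io

import torch
from torch import nn

def build_model():
    # Hypothetical factory standing in for torchvision.models.vgg16(pretrained=False):
    # rebuilds the same structure with freshly initialized parameters
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

trained = build_model()

# Saving method 2: only the state_dict (parameter name -> tensor) is written
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Loading method 2, step 1: torch.load returns a dict of tensors, not a model
state = torch.load(buffer, weights_only=True)

# Step 2: build the structure, then copy the parameters into it
fresh = build_model()
result = fresh.load_state_dict(state)
print(result.missing_keys, result.unexpected_keys)  # [] []

x = torch.randn(3, 4)
print(torch.equal(trained(x), fresh(x)))  # True
```

By default load_state_dict uses strict=True, so a structure that does not match the saved keys raises an error instead of silently loading partial weights.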
A pitfall when saving and loading network models
Write a simple network yourself, then save it with method 1.
Example code:
import torch
from torch import nn
from torch.nn import MaxPool2d
# Build the neural network
class Booze(nn.Module):
    # Call nn.Module's initializer
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)
    # Override the forward function
    def forward(self, x):
        output = self.maxpool1(x)
        return output
obj = Booze()
# Save the network model with method 1 (structure + parameters)
torch.save(obj, "obj_method1.pth")
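Run the code. Once the network model has been saved, try loading it from a separate script:
import torch
# weights_only=False (PyTorch 1.13+) stops recent PyTorch versions from rejecting the whole-model file outright
model = torch.load("obj_method1.pth", weights_only=False)
print("model", model)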
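Result: loading the model directly like this raises an error, for example AttributeError: Can't get attribute 'Booze' on <module '__main__' ...>. The cause is that method 1 pickles the model, and pickle stores only a reference to the Booze class (its module and name), not the class's code. The loading script has no Booze class, so the model cannot be rebuilt. How do we fix it?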
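import torch
from torch import nn
from torch.nn import MaxPool2d
# Build the neural network: the same class definition as in the saving script
class Booze(nn.Module):
    # Call nn.Module's initializer
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)
    # Override the forward function
    def forward(self, x):
        output = self.maxpool1(x)
        return output
# Pitfall: the Booze class must be defined before loading, otherwise unpickling fails
model = torch.load("obj_method1.pth", weights_only=False)
print("model", model)
Add the code that defines the neural network and run the script again: the error no longer occurs. Only the class definition is needed; there is no need to create an instance with obj = Booze() before loading.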
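The sketch below demonstrates the mechanism behind this pitfall, assuming PyTorch 1.13 or later. It does not use __main__. Instead it writes Booze into a hypothetical module file, booze_model.py, in a temporary directory, then saves an instance with method 1. It then makes that module unfindable, and loading fails. Once the module is importable again, loading succeeds without any explicit import, because pickle imports the module it recorded by itself. Here the failure appears as ModuleNotFoundError rather than the AttributeError seen with __main__, but the cause is the same. This also suggests an alternative to copying the class into every script: put the class in its own module that both the saving and the loading code can import.

```python
import importlib
import io
import sys
import tempfile
import textwrap
from pathlib import Path

import torch

# Source of a hypothetical shared module "booze_model.py" holding the class
MODULE_SOURCE = textwrap.dedent("""
    from torch import nn
    from torch.nn import MaxPool2d

    class Booze(nn.Module):
        def __init__(self):
            super().__init__()
            self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

        def forward(self, x):
            return self.maxpool1(x)
""")

with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "booze_model.py").write_text(MODULE_SOURCE)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    booze_model = importlib.import_module("booze_model")

    # Saving method 1 records only the reference "booze_model.Booze", not the class code
    buffer = io.BytesIO()
    torch.save(booze_model.Booze(), buffer)

    # Simulate a loading script in which the class cannot be found
    sys.path.remove(tmp)
    del sys.modules["booze_model"]
    buffer.seek(0)
    try:
        torch.load(buffer, weights_only=False)
        failed = False
    except (ImportError, AttributeError) as err:
        failed = True
        print("load failed:", type(err).__name__)

    # Make the module importable again: pickle re-imports it on its own
    sys.path.insert(0, tmp)
    buffer.seek(0)
    model = torch.load(buffer, weights_only=False)
    print("load succeeded:", type(model).__name__)
    sys.path.remove(tmp)
```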