Saving and loading network models
2022-07-08 01:01:00 【booze-J】
Saving the network model
Method 1
Example code:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# Build the VGG16 network (randomly initialized, no pretrained weights;
# on torchvision >= 0.13, weights=None is the equivalent argument)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Saving method 1: saves the model structure + the model parameters
torch.save(vgg16,"vgg16_method1.pth")
This saves the whole model object: its network structure together with all of its parameters (model structure + model parameters).
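The following is a minimal sketch that round-trips a whole model through torch.save and torch.load to show the "structure + parameters" idea. It makes three assumptions:

- A small nn.Sequential stands in for VGG16, so the large network does not have to be built.
- An in-memory io.BytesIO buffer replaces the .pth file.
- The weights_only=False argument requires PyTorch 1.13 or newer.

```python
import io

import torch
from torch import nn

# Small stand-in for vgg16 (assumption: any nn.Module behaves the same way)
model = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

# Saving method 1: pickle the whole module object (structure + parameters)
buffer = io.BytesIO()
torch.save(model, buffer)

# Read it back; weights_only=False allows un-pickling a full nn.Module
buffer.seek(0)
restored = torch.load(buffer, weights_only=False)

# The restored object has the same type and the same parameter values
same_type = type(restored) is type(model)
same_params = all(
    torch.equal(p, q) for p, q in zip(model.parameters(), restored.parameters())
)
print(same_type, same_params)
```

No separate model definition is needed before loading, because the file itself describes the structure.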
Method 2
Example code:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# Build the VGG16 network (randomly initialized, no pretrained weights)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Saving method 2: save only the parameters of vgg16, as a Python dictionary (officially recommended)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
This saves the parameters of vgg16 as a Python dictionary, so only the model parameters are stored (the officially recommended way). To read the file back, you must therefore build the network model first and then load the parameters into it.
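A state_dict is an ordered dictionary that maps each parameter (and buffer) name to its tensor. The sketch below prints one for a small hypothetical network rather than VGG16; the layer sizes are arbitrary.

```python
import torch
from torch import nn

# Hypothetical small network; indices 0 and 2 are the layers with parameters
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

state_dict = net.state_dict()
for name, tensor in state_dict.items():
    # ReLU (index 1) has no parameters, so it contributes no entries
    print(name, tuple(tensor.shape))
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)
```

The keys are just layer names plus parameter names. This is why the receiving model must have the same structure for the parameters to be matched back up.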
Loading the network model
Method 1
Example code:
import torchvision
import torch
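# Loading method 1 -> for saving method 1: loads model structure + parameters
# weights_only=False (available since PyTorch 1.13) is required on PyTorch >= 2.6,
# where torch.load otherwise only accepts tensors and plain containers
model = torch.load("vgg16_method1.pth", weights_only=False)
print("model",model)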
This works only if the model was saved with saving method 1. In that case a single call, model = torch.load("vgg16_method1.pth"), restores both the model and its parameters.
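On recent PyTorch versions, the weights_only argument of torch.load matters for method 1. To my knowledge it was added in 1.13 and became True by default in 2.6. With weights_only=True, torch.load refuses to un-pickle arbitrary Python objects such as a whole nn.Module, while a plain state_dict still loads.

The sketch below shows both cases. It assumes PyTorch 1.13 or newer, an nn.Linear standing in for vgg16, and in-memory buffers instead of files.

```python
import io
import pickle

import torch
from torch import nn

model = nn.Linear(4, 2)  # small stand-in for vgg16


def saved(obj):
    """Serialize obj with torch.save into a fresh in-memory buffer."""
    buf = io.BytesIO()
    torch.save(obj, buf)
    buf.seek(0)
    return buf


# Method-1 file (whole module): rejected when weights_only=True
try:
    torch.load(saved(model), weights_only=True)
    full_model_rejected = False
except pickle.UnpicklingError:
    full_model_rejected = True

# Method-2 file (state_dict of tensors): accepted when weights_only=True
params = torch.load(saved(model.state_dict()), weights_only=True)

print(full_model_rejected, sorted(params.keys()))
```

Only pass weights_only=False for files you trust, because un-pickling can run arbitrary code.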
Method 2
Example code:
import torchvision
import torch
# Build the network model first (same architecture as the one that was saved)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Loading method 2 -> for saving method 2: this loads only the parameters (a dict)
model = torch.load("vgg16_method2.pth")
print("model:\n",model)
# Load the parameters into the model
vgg16.load_state_dict(model)
This works only if the model was saved with saving method 2. Here model = torch.load("vgg16_method2.pth") loads only the model parameters. You first need to build the network model with vgg16 = torchvision.models.vgg16(pretrained=False), and then load the parameters into it with vgg16.load_state_dict(model).
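The same two-step pattern (build the model, then call load_state_dict) appears below in a self-contained sketch. It uses a hypothetical small network instead of VGG16 and an in-memory buffer instead of a file.

It also shows why the model you build must match the saved one. With the default strict=True, load_state_dict raises a RuntimeError when parameter names or shapes do not match.

```python
import io

import torch
from torch import nn


def make_net():
    # Hypothetical architecture; must be identical on the saving and loading side
    return nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))


# Saving side: store only the parameters
trained = make_net()
buf = io.BytesIO()
torch.save(trained.state_dict(), buf)

# Loading side: build the model first, then copy the parameters into it
buf.seek(0)
state_dict = torch.load(buf)
fresh = make_net()
fresh.load_state_dict(state_dict)
fresh.eval()  # switch to inference mode before evaluating

x = torch.randn(5, 4)
outputs_match = torch.allclose(trained(x), fresh(x))

# A model with a different architecture cannot take these parameters
wrong = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
try:
    wrong.load_state_dict(state_dict)
    mismatch_rejected = False
except RuntimeError:
    mismatch_rejected = True

print(outputs_match, mismatch_rejected)
```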
A pitfall when saving and loading network models
Write a simple network of your own and save it with saving method 1.
Example code:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# Build the neural network
class Booze(nn.Module):
    def __init__(self):
        # Run nn.Module's initialization
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
    # Override the forward function
    def forward(self,x):
        output = self.maxpool1(x)
        return output
obj = Booze()
# Save the network model (saving method 1)
torch.save(obj,"obj_method1.pth")
Run the code. Once the network model has been saved, try loading it in a separate script:
import torchvision
import torch
model = torch.load("obj_method1.pth", weights_only=False)  # weights_only=False is needed on PyTorch >= 2.6
print("model",model)
Result: loading the model directly like this raises an error (typically an AttributeError reporting that attribute 'Booze' cannot be found on module '__main__'). How can this be solved?
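Before the fix, it helps to see why loading fails. torch.save relies on Python's pickle. Pickle stores an object of a user-defined class as a reference to that class (module name plus class name) together with the instance's data, never the class's source code.

The stdlib-only sketch below imitates a fresh loading script by deleting the class before un-pickling. A plain Python class stands in for the Booze network; this is an illustration of the mechanism, not PyTorch's own code.

```python
import pickle


class Booze:
    """Plain stand-in for the nn.Module subclass in the article."""

    def __init__(self):
        self.kernel_size = 3


# "Saving script": pickle an instance of the class
data = pickle.dumps(Booze())
# The pickle only names the class; it does not contain the class's code
print(b"Booze" in data)  # True

# "Loading script" that never defined Booze: simulated by deleting the class
del Booze
try:
    pickle.loads(data)
    load_error = None
except AttributeError as err:
    load_error = err
# An AttributeError naming the missing class 'Booze'
print(repr(load_error))
```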
import torchvision
import torch
from torch import nn
from torch.nn import MaxPool2d
# Build the neural network: the same class definition as in the saving script
class Booze(nn.Module):
    def __init__(self):
        # Run nn.Module's initialization
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
    # Override the forward function
    def forward(self,x):
        output = self.maxpool1(x)
        return output
# Pitfall 1: loading only works once Booze is defined (or imported) here
model = torch.load("obj_method1.pth", weights_only=False)
print("model",model)
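Add the code that builds the neural network (the Booze class definition) and run it again: this time no error is raised. The reason is that saving method 1 pickles only a reference to the class (its module and name), not the class's code. The class must therefore be defined or importable in any script that loads the model.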
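The sketch below puts the fix together. It defines Booze once, saves it with method 1, loads it back in the same script (where the class is available), and runs a forward pass.

The 1×1×5×5 input is an assumption for illustration. With kernel_size=3 the stride defaults to 3, and ceil_mode=True keeps the partial window at the edge, so the output is 2×2.

In a real project, putting the class in its own module and importing it in both scripts achieves the same as copying the definition. A hypothetical example is `from model import Booze`.

```python
import io

import torch
from torch import nn
from torch.nn import MaxPool2d


class Booze(nn.Module):
    def __init__(self):
        super().__init__()
        # stride defaults to kernel_size; ceil_mode keeps partial windows
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, x):
        return self.maxpool1(x)


buf = io.BytesIO()
torch.save(Booze(), buf)  # saving method 1
buf.seek(0)

# Works because Booze is defined in this script
model = torch.load(buf, weights_only=False)

x = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
out = model(x)
print(out.shape)  # torch.Size([1, 1, 2, 2])
print(out)
```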