Saving and loading a network model
2022-07-08 01:01:00 【booze-J】
Saving the network model
Method 1
Example code:
import torch
import torchvision
# Build the VGG16 network (randomly initialized, no pretrained weights are downloaded)
# On torchvision 0.13+ the equivalent call is torchvision.models.vgg16(weights=None)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Saving method 1: save the whole model (model structure + model parameters)
torch.save(vgg16, "vgg16_method1.pth")
torch.save(vgg16, "vgg16_method1.pth") pickles the entire model object, so it saves both the model structure and all of its parameters.
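A minimal sketch of the Method 1 round trip, assuming a small hypothetical nn.Sequential stand-in instead of VGG16 (to keep it fast) and a temporary file. It assumes PyTorch 1.13 or newer, because it passes weights_only=False to torch.load, which PyTorch 2.6+ requires for whole pickled models.

```python
import os
import tempfile

import torch
from torch import nn

# Hypothetical small stand-in for vgg16, chosen only to keep the sketch light
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "net_method1.pth")
    # Method 1: pickle the whole module object (structure + parameters)
    torch.save(net, path)
    # Whole-module files need weights_only=False on PyTorch 2.6+ (trusted files only)
    restored = torch.load(path, weights_only=False)

# The restored object is a full module with identical weights
x = torch.randn(5, 4)
same_class = type(restored) is nn.Sequential
same_output = torch.equal(net(x), restored(x))
print(same_class, same_output)
```

No separate model-building step is needed: the loaded object can be called immediately.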
Method 2
Example code:
import torch
import torchvision
# Build the VGG16 network (randomly initialized, no pretrained weights are downloaded)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Saving method 2: save only the parameters of vgg16, as a Python dictionary (state_dict)
# This is the officially recommended way
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
torch.save(vgg16.state_dict(), "vgg16_method2.pth") saves only the parameters of vgg16, as a Python dictionary (the state_dict) that maps each parameter name to its tensor. This is the officially recommended approach. Because the file contains no model structure, restoring it means building the network model first and then loading the parameters into it.
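To show what a state_dict actually holds, here is a small sketch using a hypothetical nn.Sequential in place of VGG16. It is an ordered dictionary of names to tensors (parameters plus any persistent buffers), with no layer or class definitions.

```python
import torch
from torch import nn

# Hypothetical stand-in network; ReLU (index 1) has no parameters
net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

state = net.state_dict()
# Keys have the form "<submodule index>.<parameter name>"
for name, tensor in state.items():
    print(name, tuple(tensor.shape))
# prints:
# 0.weight (3, 4)
# 0.bias (3,)
# 2.weight (2, 3)
# 2.bias (2,)
```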
Loading the network model
Method 1
Example code:
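import torch
# Loading method 1 -> for files written with saving method 1: loads structure + parameters
# PyTorch 2.6+ defaults to weights_only=True, so a whole pickled model needs
# weights_only=False (argument available from PyTorch 1.13; use only on trusted files)
model = torch.load("vgg16_method1.pth", weights_only=False)
print("model", model)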
This requires that the model was saved with Method 1. In that case, model = torch.load("vgg16_method1.pth") loads both the model structure and its parameters directly.
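A caveat for current PyTorch releases: since 2.6, torch.load defaults to weights_only=True, which only unpickles tensors and plain containers. The sketch below (assuming PyTorch 1.13+, where the argument exists; the in-memory buffer and small network are illustrative assumptions) shows that a state_dict loads under that restricted mode, while a whole pickled module is rejected unless you explicitly opt out.

```python
import io

import torch
from torch import nn

net = nn.Sequential(nn.Linear(4, 3), nn.ReLU(), nn.Linear(3, 2))

def save_to_buffer(obj):
    """Serialize obj with torch.save into an in-memory buffer and rewind it."""
    buf = io.BytesIO()
    torch.save(obj, buf)
    buf.seek(0)
    return buf

# A state_dict (Method 2) holds only tensors, so the restricted loader accepts it
state = torch.load(save_to_buffer(net.state_dict()), weights_only=True)

# A whole module (Method 1) references classes such as nn.Sequential,
# which the restricted loader refuses to unpickle
try:
    torch.load(save_to_buffer(net), weights_only=True)
    full_model_rejected = False
except Exception:  # torch reports this as pickle.UnpicklingError
    full_model_rejected = True

# Opting out restores Method 1 loading (only for files you trust)
model = torch.load(save_to_buffer(net), weights_only=False)
print(full_model_rejected, isinstance(model, nn.Sequential))
```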
Method 2
Example code:
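import torch
import torchvision
# Build the network model (same architecture as the one that was saved)
vgg16 = torchvision.models.vgg16(pretrained=False)
# Loading method 2 -> for files written with saving method 2: loads the parameters only
model = torch.load("vgg16_method2.pth")
# model is a dictionary of parameter names -> tensors, not a network
print("model:\n", model)
# Copy the loaded parameters into the network model
vgg16.load_state_dict(model)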
This requires that the model was saved with Method 2. In that case, model = torch.load("vgg16_method2.pth") loads only the model parameters, so you first build the network model with vgg16 = torchvision.models.vgg16(pretrained=False) and then copy the parameters into it with vgg16.load_state_dict(model).
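The Method 2 round trip can be sketched as follows. Because the file holds only tensors, the receiving model must have the same architecture; the make_net factory and the deliberately mismatched width below are hypothetical, used only to illustrate this requirement.

```python
import io

import torch
from torch import nn

def make_net(hidden):
    """Hypothetical architecture factory; `hidden` sets the width of the middle layer."""
    return nn.Sequential(nn.Linear(4, hidden), nn.ReLU(), nn.Linear(hidden, 2))

source = make_net(3)
buf = io.BytesIO()
torch.save(source.state_dict(), buf)  # Method 2: parameters only

# Step 1: build the same architecture; step 2: load the saved parameters into it
buf.seek(0)
target = make_net(3)
target.load_state_dict(torch.load(buf))
matches = all(torch.equal(a, b) for a, b in zip(source.parameters(), target.parameters()))

# A network with different layer shapes cannot accept these parameters
buf.seek(0)
try:
    make_net(5).load_state_dict(torch.load(buf))
    mismatch_error = False
except RuntimeError:  # raised for size mismatches in load_state_dict
    mismatch_error = True

print(matches, mismatch_error)  # True True
```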
A pitfall when saving and loading network models
Define a simple network yourself, then save it with Method 1.
Example code:
import torch
from torch import nn
from torch.nn import MaxPool2d
# Build the neural network
class Booze(nn.Module):
    # Call nn.Module's initializer
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

    # Override the forward function
    def forward(self, x):
        output = self.maxpool1(x)
        return output

obj = Booze()
# Save the network model with method 1
torch.save(obj, "obj_method1.pth")
Run the code. Once the network model has been saved, try loading it in a separate script:
import torch
# weights_only=False is needed on PyTorch 2.6+ to unpickle a whole model
model = torch.load("obj_method1.pth", weights_only=False)
print("model", model)
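Result: loading the model directly this way raises an error. Because the loading script does not define Booze, unpickling typically fails with an AttributeError along the lines of Can't get attribute 'Booze' on <module '__main__' ...>. How can this be fixed?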
import torch
from torch import nn
from torch.nn import MaxPool2d
# Build the neural network (the same class definition as when saving)
class Booze(nn.Module):
    # Call nn.Module's initializer
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

    # Override the forward function
    def forward(self, x):
        output = self.maxpool1(x)
        return output

# Pitfall 1: Booze must be defined (or imported) before torch.load
model = torch.load("obj_method1.pth", weights_only=False)
print("model", model)
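After adding the code that defines the neural network, run it again and the error goes away. The reason is that torch.save(obj, ...) pickles the model object and records only a reference to its class (__main__.Booze), not the class code itself, so the class must be available when the model is loaded, either defined in the loading script or imported from another module.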
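To see the error and the fix side by side, the sketch below runs a save script and two load scripts in separate Python processes, since the failure only appears when the loading process has no Booze in its __main__ module. The temporary directory, the script.py file name, and the use of subprocess are assumptions made for the demonstration; weights_only=False (PyTorch 1.13+) is passed so that PyTorch 2.6+ does not reject the file for an unrelated reason.

```python
import os
import subprocess
import sys
import tempfile
import textwrap

# Class definition shared by the save script and the "fixed" load script
BOOZE_DEF = textwrap.dedent("""
    import torch
    from torch import nn
    from torch.nn import MaxPool2d

    class Booze(nn.Module):
        def __init__(self):
            super().__init__()
            self.maxpool1 = MaxPool2d(kernel_size=3, ceil_mode=True)

        def forward(self, x):
            return self.maxpool1(x)
""")

SAVE = BOOZE_DEF + 'torch.save(Booze(), "obj_method1.pth")\n'
LOAD = 'import torch\nprint(torch.load("obj_method1.pth", weights_only=False))\n'
LOAD_FIXED = BOOZE_DEF + LOAD

def run(code, cwd):
    """Write `code` to a script in `cwd`, run it in a fresh interpreter, return the result."""
    path = os.path.join(cwd, "script.py")
    with open(path, "w") as f:
        f.write(code)
    return subprocess.run([sys.executable, path], cwd=cwd, capture_output=True, text=True)

with tempfile.TemporaryDirectory() as tmp:
    saved = run(SAVE, tmp)
    broken = run(LOAD, tmp)       # Booze is not defined in this process
    fixed = run(LOAD_FIXED, tmp)  # Booze is defined before torch.load

print(saved.returncode, broken.returncode != 0, fixed.returncode)
# The last line of the failing traceback names the missing class
error_lines = broken.stderr.strip().splitlines()
print(error_lines[-1] if error_lines else "(no error output)")
```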