当前位置:网站首页>网络模型的保存与读取
网络模型的保存与读取
2022-07-07 23:11:00 【booze-J】
网络模型的保存
方式1
使用示例代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1(不仅仅保存了结构还保存了网络模型中的一些参数) 模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")
torch.save(vgg16,"vgg16_method1.pth")不仅仅保存了结构还保存了网络模型中的一些参数(保存了模型结构+模型参数)。
方式2
使用示例代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式2(将vgg16网络中的参数保存成python中的字典形式) 模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
torch.save(vgg16.state_dict(),"vgg16_method2.pth")将vgg16网络中的参数保存成python中的字典形式(保存模型参数(官方推荐)),相当于加载的时候需要先加载网络模型,再加载参数。
网络模型的读取
方式1
使用示例代码如下:
import torchvision
import torch
# 方式1 -》 保存方式1 加载模型+参数
model = torch.load("vgg16_method1.pth")
print("model",model)
前提是使用网络模型保存的方式1,直接使用model = torch.load("vgg16_method1.pth")便可以加载模型和参数。
方式2
使用示例代码如下:
import torchvision
import torch
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 方式2 -》 保存方式2 加载参数
model = torch.load("vgg16_method2.pth")
print("model:\n",model)
# 将模型参数添加到模型中
vgg16.load_state_dict(model)
前提是使用网络模型保存的方式2,使用model = torch.load("vgg16_method2.pth")加载只能把模型参数加载出来,好需要将网络模型加载出来vgg16 = torchvision.models.vgg16(pretrained=False),然后把模型参数添加到模型中vgg16.load_state_dict(model)。
网络模型的保存与读取的陷阱
自己写一个简单网络,然后使用方式1保存。
示例代码:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 搭建神经网络
class Booze(nn.Module):
# 继承nn.Module的初始化
def __init__(self):
super(Booze, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
# 重写forward函数
def forward(self,x):
output = self.maxpool1(x)
return output
obj = Booze()
# 保存网络模型
torch.save(obj,"obj_method1.pth")
运行代码,网络模型保存成功之后,我们加载一下网络模型试试看
import torchvision
import torch
model = torch.load("obj_method1.pth")
print("model",model)
结果发现:
直接这样加载模型会报错!那怎么解决呢?
import torchvision
import torch
from torch import nn
from torch.nn import MaxPool2d
# 搭建神经网络
class Booze(nn.Module):
# 继承nn.Module的初始化
def __init__(self):
super(Booze, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
# 重写forward函数
def forward(self,x):
output = self.maxpool1(x)
return output
# 陷阱1
model = torch.load("obj_method1.pth")
print("model",model)
在加载模型的代码中添加上,搭建神经网络的代码,再运行,就不会报错。
边栏推荐
- 取消select的默认样式的向下箭头和设置select默认字样
- 华为交换机S5735S-L24T4S-QA2无法telnet远程访问
- 赞!idea 如何单窗口打开多个项目?
- Jouer sonar
- Which securities company has a low, safe and reliable account opening commission
- 12.RNN应用于手写数字识别
- Application practice | the efficiency of the data warehouse system has been comprehensively improved! Data warehouse construction based on Apache Doris in Tongcheng digital Department
- NVIDIA Jetson test installation yolox process record
- Is it safe to open an account on the official website of Huatai Securities?
- 基于卷积神经网络的恶意软件检测方法
猜你喜欢

Jouer sonar

51 communicates with the Bluetooth module, and 51 drives the Bluetooth app to light up

12. RNN is applied to handwritten digit recognition

基于卷积神经网络的恶意软件检测方法

Image data preprocessing

They gathered at the 2022 ecug con just for "China's technological power"

From starfish OS' continued deflationary consumption of SFO, the value of SFO in the long run

jemter分布式

Deep dive kotlin synergy (XXII): flow treatment

赞!idea 如何单窗口打开多个项目?
随机推荐
C # generics and performance comparison
SDNU_ ACM_ ICPC_ 2022_ Summer_ Practice(1~2)
5.过拟合,dropout,正则化
letcode43:字符串相乘
[reprint] solve the problem that CONDA installs pytorch too slowly
DNS series (I): why does the updated DNS record not take effect?
Thinkphp内核工单系统源码商业开源版 多用户+多客服+短信+邮件通知
[Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output
大二级分类产品页权重低,不收录怎么办?
Malware detection method based on convolutional neural network
Huawei switch s5735s-l24t4s-qa2 cannot be remotely accessed by telnet
13. Enregistrement et chargement des modèles
8.优化器
v-for遍历元素样式失效
Image data preprocessing
They gathered at the 2022 ecug con just for "China's technological power"
Where is the big data open source project, one-stop fully automated full life cycle operation and maintenance steward Chengying (background)?
接口测试要测试什么?
What is load balancing? How does DNS achieve load balancing?
【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出