当前位置:网站首页>网络模型的保存与读取
网络模型的保存与读取
2022-07-07 23:11:00 【booze-J】
网络模型的保存
方式1
使用示例代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1(不仅仅保存了结构还保存了网络模型中的一些参数) 模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")
torch.save(vgg16,"vgg16_method1.pth")
不仅仅保存了结构还保存了网络模型中的一些参数(保存了模型结构+模型参数)。
方式2
使用示例代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式2(将vgg16网络中的参数保存成python中的字典形式) 模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
将vgg16网络中的参数保存成python中的字典形式(保存模型参数(官方推荐)),相当于加载的时候需要先加载网络模型,再加载参数。
网络模型的读取
方式1
使用示例代码如下:
import torchvision
import torch
# 方式1 -》 保存方式1 加载模型+参数
model = torch.load("vgg16_method1.pth")
print("model",model)
前提是使用网络模型保存的方式1,直接使用model = torch.load("vgg16_method1.pth")
便可以加载模型和参数。
方式2
使用示例代码如下:
import torchvision
import torch
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 方式2 -》 保存方式2 加载参数
model = torch.load("vgg16_method2.pth")
print("model:\n",model)
# 将模型参数添加到模型中
vgg16.load_state_dict(model)
前提是使用网络模型保存的方式2,使用model = torch.load("vgg16_method2.pth")
加载只能把模型参数加载出来,好需要将网络模型加载出来vgg16 = torchvision.models.vgg16(pretrained=False)
,然后把模型参数添加到模型中vgg16.load_state_dict(model)
。
网络模型的保存与读取的陷阱
自己写一个简单网络,然后使用方式1保存。
示例代码:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 搭建神经网络
class Booze(nn.Module):
# 继承nn.Module的初始化
def __init__(self):
super(Booze, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
# 重写forward函数
def forward(self,x):
output = self.maxpool1(x)
return output
obj = Booze()
# 保存网络模型
torch.save(obj,"obj_method1.pth")
运行代码,网络模型保存成功之后,我们加载一下网络模型试试看
import torchvision
import torch
model = torch.load("obj_method1.pth")
print("model",model)
结果发现:
直接这样加载模型会报错!那怎么解决呢?
import torchvision
import torch
from torch import nn
from torch.nn import MaxPool2d
# 搭建神经网络
class Booze(nn.Module):
# 继承nn.Module的初始化
def __init__(self):
super(Booze, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
# 重写forward函数
def forward(self,x):
output = self.maxpool1(x)
return output
# 陷阱1
model = torch.load("obj_method1.pth")
print("model",model)
在加载模型的代码中添加上,搭建神经网络的代码,再运行,就不会报错。
边栏推荐
- 13.模型的保存和載入
- Jouer sonar
- Langchao Yunxi distributed database tracing (II) -- source code analysis
- What is load balancing? How does DNS achieve load balancing?
- 5.过拟合,dropout,正则化
- [necessary for R & D personnel] how to make your own dataset and display it.
- Fofa attack and defense challenge record
- 炒股开户怎么最方便,手机上开户安全吗
- v-for遍历元素样式失效
- 8道经典C语言指针笔试题解析
猜你喜欢
New library launched | cnopendata China Time-honored enterprise directory
My best game based on wechat applet development
ReentrantLock 公平锁源码 第0篇
What has happened from server to cloud hosting?
After going to ByteDance, I learned that there are so many test engineers with an annual salary of 40W?
1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
Password recovery vulnerability of foreign public testing
第一讲:链表中环的入口结点
13. Enregistrement et chargement des modèles
What does interface testing test?
随机推荐
5.过拟合,dropout,正则化
NTT template for Tourism
1293_FreeRTOS中xTaskResumeAll()接口的实现分析
串口接收一包数据
8.优化器
Image data preprocessing
LeetCode刷题
Where is the big data open source project, one-stop fully automated full life cycle operation and maintenance steward Chengying (background)?
9.卷积神经网络介绍
Summary of weidongshan phase II course content
Kubernetes static pod (static POD)
12.RNN应用于手写数字识别
Invalid V-for traversal element style
How to learn a new technology (programming language)
13.模型的保存和载入
1293_ Implementation analysis of xtask resumeall() interface in FreeRTOS
How to insert highlighted code blocks in WPS and word
How does starfish OS enable the value of SFO in the fourth phase of SFO destruction?
Su embedded training - day4
Lecture 1: the entry node of the link in the linked list