当前位置:网站首页>网络模型的保存与读取
网络模型的保存与读取
2022-07-07 23:11:00 【booze-J】
网络模型的保存
方式1
使用示例代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1(不仅仅保存了结构还保存了网络模型中的一些参数) 模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")
torch.save(vgg16,"vgg16_method1.pth")
不仅仅保存了结构还保存了网络模型中的一些参数(保存了模型结构+模型参数)。
方式2
使用示例代码如下:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式2(将vgg16网络中的参数保存成python中的字典形式) 模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
将vgg16网络中的参数保存成python中的字典形式(保存模型参数(官方推荐)),相当于加载的时候需要先加载网络模型,再加载参数。
网络模型的读取
方式1
使用示例代码如下:
import torchvision
import torch
# 方式1 -》 保存方式1 加载模型+参数
model = torch.load("vgg16_method1.pth")
print("model",model)
前提是使用网络模型保存的方式1,直接使用model = torch.load("vgg16_method1.pth")
便可以加载模型和参数。
方式2
使用示例代码如下:
import torchvision
import torch
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 方式2 -》 保存方式2 加载参数
model = torch.load("vgg16_method2.pth")
print("model:\n",model)
# 将模型参数添加到模型中
vgg16.load_state_dict(model)
前提是使用网络模型保存的方式2,使用model = torch.load("vgg16_method2.pth")
加载只能把模型参数加载出来,好需要将网络模型加载出来vgg16 = torchvision.models.vgg16(pretrained=False)
,然后把模型参数添加到模型中vgg16.load_state_dict(model)
。
网络模型的保存与读取的陷阱
自己写一个简单网络,然后使用方式1保存。
示例代码:
import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 搭建神经网络
class Booze(nn.Module):
# 继承nn.Module的初始化
def __init__(self):
super(Booze, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
# 重写forward函数
def forward(self,x):
output = self.maxpool1(x)
return output
obj = Booze()
# 保存网络模型
torch.save(obj,"obj_method1.pth")
运行代码,网络模型保存成功之后,我们加载一下网络模型试试看
import torchvision
import torch
model = torch.load("obj_method1.pth")
print("model",model)
结果发现:
直接这样加载模型会报错!那怎么解决呢?
import torchvision
import torch
from torch import nn
from torch.nn import MaxPool2d
# 搭建神经网络
class Booze(nn.Module):
# 继承nn.Module的初始化
def __init__(self):
super(Booze, self).__init__()
self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)
# 重写forward函数
def forward(self,x):
output = self.maxpool1(x)
return output
# 陷阱1
model = torch.load("obj_method1.pth")
print("model",model)
在加载模型的代码中添加上,搭建神经网络的代码,再运行,就不会报错。
边栏推荐
- 【笔记】常见组合滤波电路
- 攻防演练中沙盘推演的4个阶段
- LeetCode刷题
- 5.过拟合,dropout,正则化
- 德总理称乌不会获得“北约式”安全保障
- Invalid V-for traversal element style
- Analysis of 8 classic C language pointer written test questions
- 【obs】Impossible to find entrance point CreateDirect3D11DeviceFromDXGIDevice
- How to insert highlighted code blocks in WPS and word
- [Yugong series] go teaching course 006 in July 2022 - automatic derivation of types and input and output
猜你喜欢
赞!idea 如何单窗口打开多个项目?
【愚公系列】2022年7月 Go教学课程 006-自动推导类型和输入输出
What does interface testing test?
新库上线 | CnOpenData中国星级酒店数据
Reentrantlock fair lock source code Chapter 0
How does starfish OS enable the value of SFO in the fourth phase of SFO destruction?
第一讲:链表中环的入口结点
5.过拟合,dropout,正则化
8道经典C语言指针笔试题解析
Reptile practice (VIII): reptile expression pack
随机推荐
英雄联盟胜负预测--简易肯德基上校
Leetcode brush questions
After going to ByteDance, I learned that there are so many test engineers with an annual salary of 40W?
Basic principle and usage of dynamic library, -fpic option context
基于卷积神经网络的恶意软件检测方法
Invalid V-for traversal element style
第四期SFO销毁,Starfish OS如何对SFO价值赋能?
新库上线 | CnOpenData中国星级酒店数据
华泰证券官方网站开户安全吗?
5G NR 系统消息
ReentrantLock 公平锁源码 第0篇
ABAP ALV LVC template
1293_FreeRTOS中xTaskResumeAll()接口的实现分析
Introduction to ML regression analysis of AI zhetianchuan
Reentrantlock fair lock source code Chapter 0
AI zhetianchuan ml novice decision tree
《因果性Causality》教程,哥本哈根大学Jonas Peters讲授
CVE-2022-28346:Django SQL注入漏洞
NVIDIA Jetson测试安装yolox过程记录
Semantic segmentation model base segmentation_ models_ Detailed introduction to pytorch