当前位置:网站首页>网络模型的保存与读取

网络模型的保存与读取

2022-07-07 23:11:00 booze-J

网络模型的保存

方式1

使用示例代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d

# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)

# 保存方式1(不仅仅保存了结构还保存了网络模型中的一些参数) 模型结构+模型参数
torch.save(vgg16,"vgg16_method1.pth")

torch.save(vgg16,"vgg16_method1.pth")不仅仅保存了结构还保存了网络模型中的一些参数(保存了模型结构+模型参数)。

方式2

使用示例代码如下:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d

# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式2(将vgg16网络中的参数保存成python中的字典形式) 模型参数(官方推荐)
torch.save(vgg16.state_dict(),"vgg16_method2.pth")

torch.save(vgg16.state_dict(),"vgg16_method2.pth")将vgg16网络中的参数保存成python中的字典形式(保存模型参数(官方推荐)),相当于加载的时候需要先加载网络模型,再加载参数。

网络模型的读取

方式1

使用示例代码如下:

import torchvision
import torch

# 方式1 -》 保存方式1 加载模型+参数
model = torch.load("vgg16_method1.pth")
print("model",model)

前提是使用网络模型保存的方式1,直接使用model = torch.load("vgg16_method1.pth")便可以加载模型和参数。

方式2

使用示例代码如下:

import torchvision
import torch
# 加载网络模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 方式2 -》 保存方式2 加载参数
model = torch.load("vgg16_method2.pth")
print("model:\n",model)
# 将模型参数添加到模型中
vgg16.load_state_dict(model)

前提是使用网络模型保存的方式2,使用model = torch.load("vgg16_method2.pth")加载只能把模型参数加载出来,好需要将网络模型加载出来vgg16 = torchvision.models.vgg16(pretrained=False),然后把模型参数添加到模型中vgg16.load_state_dict(model)

网络模型的保存与读取的陷阱

自己写一个简单网络,然后使用方式1保存。
示例代码:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
# 搭建神经网络
class Booze(nn.Module):

    # 继承nn.Module的初始化
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)

    # 重写forward函数
    def forward(self,x):
        output = self.maxpool1(x)
        return output



obj = Booze()
# 保存网络模型
torch.save(obj,"obj_method1.pth")

运行代码,网络模型保存成功之后,我们加载一下网络模型试试看

import torchvision
import torch
model = torch.load("obj_method1.pth")
print("model",model)

结果发现:
在这里插入图片描述
直接这样加载模型会报错!那怎么解决呢?

import torchvision
import torch
from torch import nn
from torch.nn import MaxPool2d

# 搭建神经网络
class Booze(nn.Module):

    # 继承nn.Module的初始化
    def __init__(self):
        super(Booze, self).__init__()
        self.maxpool1 = MaxPool2d(kernel_size=3,ceil_mode=True)

    # 重写forward函数
    def forward(self,x):
        output = self.maxpool1(x)
        return output
# 陷阱1
model = torch.load("obj_method1.pth")
print("model",model)

在加载模型的代码中添加上,搭建神经网络的代码,再运行,就不会报错。
在这里插入图片描述

原网站

版权声明
本文为[booze-J]所创,转载请带上原文链接,感谢
https://blog.csdn.net/booze_/article/details/125554415