当前位置:网站首页>PyTorch(15)---模型保存和加载
PyTorch(15)---模型保存和加载
2022-08-02 14:07:00 【伏月三十】
模型保存和加载
模型保存
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
#模型:vgg16
vgg16=torchvision.models.vgg16(pretrained=False)
'''第一种:模型参数都保存'''
torch.save(vgg16,"vgg16_method1.pth")
'''第二种:只保存参数,保存成字典(官方推荐)'''
torch.save(vgg16.state_dict(),"vgg16_method2.pth")
'''陷阱1:使用自己搭建的网络(第一种)'''
class Demo(nn.Module):
def __init__(self) -> None:
super().__init__()
self.model1=Sequential(
Conv2d(in_channels=3, out_channels=32, kernel_size=5, stride=1, padding=2, dilation=1, ),
MaxPool2d(kernel_size=2, ),
Conv2d(in_channels=32, out_channels=32, kernel_size=5, stride=1, padding=2, ),
MaxPool2d(kernel_size=2),
Conv2d(32, 64, 5, 1, 2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 10),
)
def forward(self,x):
x=self.model1(x)
return x
#模型2
demo=Demo()
torch.save(demo,"demo_method1.pth")
模型加载
import torch
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from model_save import *
'''方式1:(使用保存方式1),加载模型'''
model=torch.load("vgg16_method1.pth")
print(model)
print("------------------------------------------------------------")
'''方式2'''
#先把结构搞出来
vgg16=torchvision.models.vgg16(pretrained=False)
#再把字典形式的参数放进去(自己训练的参数!!!)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)
print("------------------------------------------------------------")
'''陷阱1:要把自己的网络模型放过来,但是不用实例化了 或者将模型保存的文件import过来 '''
medel1=torch.load("demo_method1.pth")
print(medel1)
边栏推荐
- 6. How to use the CardView production card layout effect
- 宝塔面板搭建小说CMS管理系统源码实测 - ThinkPHP6.0
- 想做好分布式架构?这个知识点一定要理解透彻
- Flink时间和窗口
- LLVM系列第十七章:控制流语句for
- UIWindow的makeKeyAndVisible不调用rootviewController 的viewDidLoad的问题
- 6.如何使用CardView制作卡片布局效果
- 内存申请(malloc)和释放(free)之下篇
- redis入门-1-redis概念和基础
- Enhanced Apktool reverse artifact
猜你喜欢
随机推荐
什么?都0202年了,你还不会屏幕适配?
MySQL知识总结 (二) 存储引擎
Ffmpeg交叉编译
1. What is RecyclerView
LLVM系列第九章:控制流语句if-else
NER(命名体识别)之 FLAT模型
Seq2Seq模型PyTorch版本
It is not allowed to subscribe with a(n) xxx multiple times.Please create a fresh instance of xxx
liunx下mysql遇到的简单问题
IllegalStateException: Room cannot verify the data integrity. Looks like you've changed schema but
spark写sql的方式
Spark_Core
PostgreSQL 性能谜题
电商项目常见连续登录,消费,日期等问题
十分钟带你入门Nodejs
Cannot figure out how to save this field into database. You can consider adding a type converter for
Flink实现Exactly Once
MySQL知识总结 (三) 索引
Win10不能启动WampServer图标呈橘黄色的解决方法
Redis-01-Nosql概述