Saving, Loading, and Resuming Training of PyTorch Models
2022-06-22 19:27:00 【Weiyaner】
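As models keep growing, training one to completion in a single run is increasingly impractical on low-compute platforms. It is therefore worth saving the model during training, so that the next session can continue from where the previous one stopped instead of starting over, which saves time. The code is as follows:
Imports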
import torch
from torch import nn
import numpy as np
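Define the model
Define a three-layer MLP classifier (64 input features, a 32-unit hidden layer, 10 output classes):
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)   # input layer -> hidden layer
        self.linear1 = nn.Linear(32, 10)  # hidden layer -> 10 class logits
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear1(x)
        return x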
## Randomly generate two labeled datasets (100 samples x 64 features, labels in 0-9)
rand1 = torch.rand((100, 64)).to(torch.float)
label1 = np.random.randint(0, 10, size=100)
label1 = torch.from_numpy(label1).to(torch.long)  # CrossEntropyLoss expects int64 class indices
rand2 = torch.rand((100, 64)).to(torch.float)
label2 = np.random.randint(0, 10, size=100)
label2 = torch.from_numpy(label2).to(torch.long)
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss = nn.CrossEntropyLoss()
## Train for 10 epochs (each epoch is one full-batch step on rand1)
epoch = 10
for i in range(epoch):
    output = model(rand1)
    my_loss = loss(output, label1)
    optimizer.zero_grad()
    my_loss.backward()
    optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss))
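The motivation above is to save the model while training is still running. One common way to do that is to write a checkpoint every few epochs inside the loop. This is a minimal sketch, not the author's code: the interval `save_every`, the temporary directory, and the file-name pattern `ckpt_epoch{N}.pth` are hypothetical choices, and the dictionary keys simply mirror the `save_model` function used later in this post.

```python
import os
import tempfile

import torch
from torch import nn


class MyModel(nn.Module):
    # Same architecture as above: 64 -> 32 -> 10
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear1(self.relu(self.linear(x)))


torch.manual_seed(0)
x = torch.rand(100, 64)
y = torch.randint(0, 10, (100,))

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

ckpt_dir = tempfile.mkdtemp()  # hypothetical location; use your own directory
save_every = 5                 # hypothetical checkpoint interval (in epochs)
saved_files = []

for i in range(10):
    loss = criterion(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (i + 1) % save_every == 0:
        path = os.path.join(ckpt_dir, f"ckpt_epoch{i + 1}.pth")
        torch.save({"epoch": i + 1,  # completed epochs = next epoch index
                    "optimizer_dict": optimizer.state_dict(),
                    "model_dict": model.state_dict()}, path)
        saved_files.append(path)

print([os.path.basename(p) for p in saved_files])  # ['ckpt_epoch5.pth', 'ckpt_epoch10.pth']
```

If a run is interrupted, the most recent file can be loaded and training continued from its `epoch` value.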
Output (note these loss values; we will compare them with the initial loss when training resumes):
epoch:0 loss:2.3494179248809814
epoch:1 loss:2.287858009338379
epoch:2 loss:2.2486231327056885
epoch:3 loss:2.2189149856567383
epoch:4 loss:2.193182945251465
epoch:5 loss:2.167125940322876
epoch:6 loss:2.140075206756592
epoch:7 loss:2.1100614070892334
epoch:8 loss:2.0764594078063965
epoch:9 loss:2.0402779579162598
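Saving the model
Models are saved with torch.save, generally in one of two ways: the first simply saves the entire model object (architecture and parameters together); the second saves the individual components (model parameters, optimizer state, and so on) into a dictionary.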
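# Save the entire model object
import os
save_path = r'model_para/'
os.makedirs(save_path, exist_ok=True)  # torch.save does not create missing directories
torch.save(model, save_path+'model_full.pth')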
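Saving the whole model pickles the module object, so loading it back needs the class definition to be importable, and on recent PyTorch releases (2.6 and later, where `torch.load` defaults to `weights_only=True`) it also needs `weights_only=False` to be passed explicitly. Because unpickling can execute arbitrary code, only load full-model files you trust. This sketch uses an in-memory `io.BytesIO` buffer as a stand-in for `model_full.pth`, and checks that the restored model gives the same outputs.

```python
import io

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear1(self.relu(self.linear(x)))


model = MyModel()

# An in-memory buffer stands in for the 'model_full.pth' file.
buffer = io.BytesIO()
torch.save(model, buffer)  # pickles the whole module object
buffer.seek(0)

# Full-module checkpoints are not "weights only", so opt out explicitly.
restored = torch.load(buffer, weights_only=False)
restored.eval()

x = torch.rand(3, 64)
with torch.no_grad():
    same = torch.equal(model(x), restored(x))
print(type(restored).__name__, same)  # MyModel True
```

The state-dict approach below avoids the pickling dependency on the class, which is one reason it is usually preferred for resuming training.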
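Save the model parameters, the optimizer state, and the epoch counter into one dictionary:
def save_model(save_path, epoch, optimizer, model):
    torch.save({
        'epoch': epoch+1,                          # completed epochs = index of the next epoch
        'optimizer_dict': optimizer.state_dict(),  # e.g. Adam's moment estimates and step counts
        'model_dict': model.state_dict()},         # weights and biases of every layer
        save_path)
    print("model save success")
# Pass the last epoch index i (9), not the epoch count (10), so the saved 'epoch' is 10 rather than 11
save_model(save_path+'model_dict.pth', i, optimizer, model)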
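To see what such a checkpoint actually contains, the sketch below trains for three steps, saves with the same `save_model` logic (minus the print) into an in-memory buffer, and inspects the loaded dictionary. The buffer and the tiny random dataset are assumptions for illustration; the key names `exp_avg`, `exp_avg_sq`, and `step` are the per-parameter state that `torch.optim.Adam` keeps.

```python
import io

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear1(self.relu(self.linear(x)))


def save_model(save_path, epoch, optimizer, model):
    torch.save({
        'epoch': epoch + 1,
        'optimizer_dict': optimizer.state_dict(),
        'model_dict': model.state_dict()},
        save_path)


model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
x = torch.rand(8, 64)
y = torch.randint(0, 10, (8,))
for i in range(3):
    loss = nn.functional.cross_entropy(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

buffer = io.BytesIO()
save_model(buffer, i, optimizer, model)  # i == 2 is the last epoch index
buffer.seek(0)
ckpt = torch.load(buffer)

print(ckpt['epoch'])  # 3
print(list(ckpt['model_dict'].keys()))
# ['linear.weight', 'linear.bias', 'linear1.weight', 'linear1.bias']
print(sorted(ckpt['optimizer_dict']['state'][0].keys()))
# ['exp_avg', 'exp_avg_sq', 'step']
print(ckpt['optimizer_dict']['param_groups'][0]['lr'])  # 0.01
```

Saving the optimizer state matters for Adam: without `exp_avg` and `exp_avg_sq`, a resumed run would restart its moment estimates from zero.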
Loading the model
Saved .pth files are loaded with torch.load, and each part is then restored with load_state_dict:
def load_model(save_name, optimizer, model):
    model_data = torch.load(save_name)
    model.load_state_dict(model_data['model_dict'])
    optimizer.load_state_dict(model_data['optimizer_dict'])
    print("model load success")
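The `load_model` above restores the weights and optimizer state but discards the saved `'epoch'` value. A small variant can return it so the training loop knows where to continue, and can pass `map_location` so a checkpoint written on a GPU machine also loads on a CPU-only one. The function name `load_model_with_epoch` and its `device` parameter are hypothetical additions, not part of the original post.

```python
import io

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear1(self.relu(self.linear(x)))


def load_model_with_epoch(save_name, optimizer, model, device="cpu"):
    """Hypothetical variant of load_model that also returns the saved epoch."""
    # map_location remaps tensors saved on another device (e.g. a GPU) onto `device`
    model_data = torch.load(save_name, map_location=device)
    model.load_state_dict(model_data['model_dict'])
    optimizer.load_state_dict(model_data['optimizer_dict'])
    return model_data['epoch']  # index of the next epoch to run


# Build a checkpoint in memory with the same keys that save_model writes.
model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
buffer = io.BytesIO()
torch.save({'epoch': 10,
            'optimizer_dict': optimizer.state_dict(),
            'model_dict': model.state_dict()}, buffer)
buffer.seek(0)

new_model = MyModel()
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
start_epoch = load_model_with_epoch(buffer, new_optimizer, new_model)
print(start_epoch)  # 10
```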
Inspect the first layer's weights in the currently trained model:
print(model.state_dict()['linear.weight'])
tensor([[-0.0215, 0.0299, -0.0255, ..., -0.0997, -0.0899, 0.0499],
[-0.0113, -0.0974, 0.1020, ..., 0.0874, -0.0744, 0.0801],
[ 0.0471, 0.1373, 0.0069, ..., -0.0573, -0.0199, -0.0654],
...,
[ 0.0693, 0.1900, 0.0013, ..., -0.0348, 0.1541, 0.1372],
[ 0.1672, -0.0086, 0.0189, ..., 0.0926, 0.1545, 0.0934],
[-0.0773, 0.0645, -0.1544, ..., -0.1130, 0.0213, -0.0613]])
Create a new model, load the saved checkpoint into it, and print the same layer's weights:
new_model = MyModel()
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
load_model(save_path+'model_dict.pth', new_optimizer, new_model)
print(new_model.state_dict()['linear.weight'])
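The new model's parameters are identical to those of the current model, which shows that the checkpoint was loaded successfully: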
model load success
tensor([[-0.0215, 0.0299, -0.0255, ..., -0.0997, -0.0899, 0.0499],
[-0.0113, -0.0974, 0.1020, ..., 0.0874, -0.0744, 0.0801],
[ 0.0471, 0.1373, 0.0069, ..., -0.0573, -0.0199, -0.0654],
...,
[ 0.0693, 0.1900, 0.0013, ..., -0.0348, 0.1541, 0.1372],
[ 0.1672, -0.0086, 0.0189, ..., 0.0926, 0.1545, 0.0934],
[-0.0773, 0.0645, -0.1544, ..., -0.1130, 0.0213, -0.0613]])
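Comparing printed tensors by eye only checks the visible corners of one layer. A sketch of a programmatic check, assuming we want every entry of the state dict to match exactly, uses `torch.equal` over all keys; the helper name `same_state` is hypothetical.

```python
import io

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear1(self.relu(self.linear(x)))


def same_state(sd_a, sd_b):
    """Return True if two state dicts have the same keys and identical tensors."""
    return sd_a.keys() == sd_b.keys() and all(torch.equal(sd_a[k], sd_b[k]) for k in sd_a)


model = MyModel()
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

new_model = MyModel()  # freshly initialized, so its weights differ
before = same_state(model.state_dict(), new_model.state_dict())
new_model.load_state_dict(torch.load(buffer))
after = same_state(model.state_dict(), new_model.state_dict())
print(before, after)  # False True
```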
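Resuming training
Continue training the new model from the loaded parameters and watch the loss. It picks up from where the previous run ended (the first run's last reported loss was 2.0403; the first resumed epoch reports 2.0037) and keeps decreasing, which shows that training resumed successfully.
epoch = 10
for i in range(epoch):
    output = new_model(rand1)
    my_loss = loss(output, label1)
    new_optimizer.zero_grad()
    my_loss.backward()
    new_optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss))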
epoch:0 loss:2.0036799907684326
epoch:1 loss:1.965193271636963
epoch:2 loss:1.924098253250122
epoch:3 loss:1.881495714187622
epoch:4 loss:1.835693359375
epoch:5 loss:1.7865667343139648
epoch:6 loss:1.7352293729782104
epoch:7 loss:1.6832704544067383
epoch:8 loss:1.6308385133743286
epoch:9 loss:1.5763107538223267
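A stricter check than "the loss keeps going down" is that 10 epochs, a checkpoint, and 10 more epochs in a fresh model and optimizer should end at the same weights as 20 uninterrupted epochs. The sketch below assumes full-batch training on CPU with fixed seeds; it compares with `torch.allclose` rather than exact equality to tolerate any floating-point noise. The `train` helper and the seeds are illustrative assumptions.

```python
import io

import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear1(self.relu(self.linear(x)))


def train(model, optimizer, x, y, start_epoch, num_epochs):
    criterion = nn.CrossEntropyLoss()
    # Epoch numbering continues from start_epoch instead of restarting at 0.
    for epoch in range(start_epoch, start_epoch + num_epochs):
        loss = criterion(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    return start_epoch + num_epochs  # index of the next epoch to run


torch.manual_seed(0)
x = torch.rand(100, 64)
y = torch.randint(0, 10, (100,))

# Reference run: 20 uninterrupted epochs.
torch.manual_seed(1)
ref_model = MyModel()
ref_opt = torch.optim.Adam(ref_model.parameters(), lr=0.01)
train(ref_model, ref_opt, x, y, 0, 20)

# Interrupted run: same initial weights, 10 epochs, then checkpoint.
torch.manual_seed(1)
model = MyModel()
opt = torch.optim.Adam(model.parameters(), lr=0.01)
next_epoch = train(model, opt, x, y, 0, 10)
buffer = io.BytesIO()
torch.save({'epoch': next_epoch,
            'optimizer_dict': opt.state_dict(),
            'model_dict': model.state_dict()}, buffer)
buffer.seek(0)

# Resume in a brand-new model and optimizer.
resumed = MyModel()
resumed_opt = torch.optim.Adam(resumed.parameters(), lr=0.01)
ckpt = torch.load(buffer)
resumed.load_state_dict(ckpt['model_dict'])
resumed_opt.load_state_dict(ckpt['optimizer_dict'])
train(resumed, resumed_opt, x, y, ckpt['epoch'], 10)

match = all(torch.allclose(a, b) for a, b in
            zip(ref_model.state_dict().values(), resumed.state_dict().values()))
print(match)
```

Dropping the `resumed_opt.load_state_dict(...)` line would generally break this match, because Adam's moment estimates and step count would restart from scratch.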
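Problems caused by mismatched data
I noticed another issue here. Two datasets were generated earlier, and the original model was trained on rand1. The loaded parameters only help if training continues on rand1; if we switch to rand2, the model effectively trains from scratch (see the losses below, which even start higher than a freshly initialized model's 2.35). Strictly speaking, rand1 and rand2 are drawn from the same uniform distribution, but their labels are random and independent of the inputs, so there is no real input-label relationship to learn. What the model learned on rand1 is memorization of those particular samples, which is of essentially no use on rand2.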
epoch:0 loss:2.523787498474121
epoch:1 loss:2.469816207885742
epoch:2 loss:2.4141526222229004
epoch:3 loss:2.379054069519043
epoch:4 loss:2.3563807010650635
epoch:5 loss:2.319946765899658
epoch:6 loss:2.271805763244629
epoch:7 loss:2.2274367809295654
epoch:8 loss:2.186885118484497
epoch:9 loss:2.144239902496338
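For contrast, when labels genuinely depend on the inputs, what a model learns on one sample does carry over to a fresh sample from the same distribution. The sketch below uses a hypothetical labeling rule (the label is the index of the largest of the first 10 features), trains only on the first set, and measures the loss on a second, unseen set before and after training. The rule, dataset sizes, and step count are illustrative assumptions.

```python
import torch
from torch import nn


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)
        self.linear1 = nn.Linear(32, 10)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear1(self.relu(self.linear(x)))


def make_data(n):
    x = torch.rand(n, 64)
    # Hypothetical rule: label = index of the largest of the first 10 features,
    # so labels depend on the inputs (unlike the random labels above).
    y = x[:, :10].argmax(dim=1)
    return x, y


torch.manual_seed(0)
x1, y1 = make_data(5000)  # training set
x2, y2 = make_data(500)   # unseen set from the same distribution

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

with torch.no_grad():
    before = criterion(model(x2), y2).item()

for _ in range(300):
    loss = criterion(model(x1), y1)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

with torch.no_grad():
    after = criterion(model(x2), y2).item()

print(f"loss on the unseen second set: {before:.3f} -> {after:.3f}")
```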
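In real training, however, all batches are assumed to come from the same distribution (and the labels actually depend on the inputs), so this issue does not normally arise.
That completes the three key steps in PyTorch: saving a model, loading it, and resuming training. I hope this helps!
Happy training.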