当前位置:网站首页>pytorch训练好的模型在加载和保存过程中的问题

pytorch训练好的模型在加载和保存过程中的问题

2022-07-06 08:27:00 MAR-Sky

在gpu上训练完成,在cpu上加载

torch.save(model.state_dict(), PATH)# 在gpu上训练后保存

# 在cpu的模型上加载使用
model.load_state_dict(torch.load(PATH, map_location='cpu'))

在cpu上训练完成,在gpu上加载

torch.save(model.state_dict(), PATH)# 在gpu上训练后保存

# 在cpu的模型上加载使用
model.load_state_dict(torch.load(PATH, map_location='cuda:0'))

在使用中需要注意的加载内容

当数据放入GPU,需要训练的模型也要放入GPU

''' data_loader:pytorch中加载数据 '''
 for i, sample in enumerate(data_loader):  # 对数据进行按批次遍历
     image, target = sample  # 每一批次加载返回值
     if CUDA:
         image = image.cuda()   # 输入输出传入gpu
         target = target.cuda()
     # print(target.size)
     optimizer.zero_grad()     # 优化函数
     output = mymodel(image)

mymodel.to(torch.device("cuda"))

在这里插入图片描述

多个gpu训练时的加载

参考:https://blog.csdn.net/weixin_43794311/article/details/120940090

import torch.nn as nn
mymodel = nn.DataParallel(mymodel)

pytorch中的nn模块使用nn.DataParallel将模型加载到多个GPU,需要注意,这种加载方式保存的权重参数会比不使用nn.DataParallel加载模型保存的权重参数的关键字前多一个"module."。是否使用nn.DataParallel加载模型,会导致下次再加载模型的时候可能会出现下图的问题,
在这里插入图片描述
当权重参数前面多一个“module."时,最简单的方式就是使用nn.DataParallel对模型加载,

原网站

版权声明
本文为[MAR-Sky]所创,转载请带上原文链接,感谢
https://blog.csdn.net/weixin_43794311/article/details/125517326