当前位置:网站首页>pytorch训练好的模型在加载和保存过程中的问题
pytorch训练好的模型在加载和保存过程中的问题
2022-07-06 08:27:00 【MAR-Sky】
在gpu上训练完成,在cpu上加载
torch.save(model.state_dict(), PATH)# 在gpu上训练后保存
# 在cpu的模型上加载使用
model.load_state_dict(torch.load(PATH, map_location='cpu'))
在cpu上训练完成,在gpu上加载
torch.save(model.state_dict(), PATH)# 在gpu上训练后保存
# 在cpu的模型上加载使用
model.load_state_dict(torch.load(PATH, map_location='cuda:0'))
在使用中需要注意的加载内容
当数据放入GPU,需要训练的模型也要放入GPU
''' data_loader:pytorch中加载数据 '''
for i, sample in enumerate(data_loader): # 对数据进行按批次遍历
image, target = sample # 每一批次加载返回值
if CUDA:
image = image.cuda() # 输入输出传入gpu
target = target.cuda()
# print(target.size)
optimizer.zero_grad() # 优化函数
output = mymodel(image)
mymodel.to(torch.device("cuda"))

多个gpu训练时的加载
参考:https://blog.csdn.net/weixin_43794311/article/details/120940090
import torch.nn as nn
mymodel = nn.DataParallel(mymodel)
pytorch中的nn模块使用nn.DataParallel将模型加载到多个GPU,需要注意,这种加载方式保存的权重参数会比不使用nn.DataParallel加载模型保存的权重参数的关键字前多一个"module."。是否使用nn.DataParallel加载模型,会导致下次再加载模型的时候可能会出现下图的问题,
当权重参数前面多一个“module."时,最简单的方式就是使用nn.DataParallel对模型加载,
边栏推荐
- The ECU of 21 Audi q5l 45tfsi brushes is upgraded to master special adjustment, and the horsepower is safely and stably increased to 305 horsepower
- On the day of resignation, jd.com deleted the database and ran away, and the programmer was sentenced
- Synchronized solves problems caused by sharing
- Vocabulary notes for postgraduate entrance examination (3)
- ROS编译 调用第三方动态库(xxx.so)
- Deep learning: derivation of shallow neural networks and deep neural networks
- Restore backup data on S3 compatible storage with br
- 【刷题】牛客网面试必刷TOP101
- 使用 TiDB Lightning 恢复 S3 兼容存储上的备份数据
- Precise query of tree tree
猜你喜欢

Online yaml to CSV tool
![[research materials] 2021 live broadcast annual data report of e-commerce - Download attached](/img/a6/74da2f44c7b6b22fed2f8e41a55988.jpg)
[research materials] 2021 live broadcast annual data report of e-commerce - Download attached

IOT -- interpreting the four tier architecture of the Internet of things

C language custom type: struct

matplotlib. Widgets are easy to use

Convolution, pooling, activation function, initialization, normalization, regularization, learning rate - Summary of deep learning foundation
![[research materials] 2021 Research Report on China's smart medical industry - Download attached](/img/c8/a205ddc2835c87efa38808cf31f59e.jpg)
[research materials] 2021 Research Report on China's smart medical industry - Download attached
![[cloud native topic -45]:kubesphere cloud Governance - Introduction and overall architecture of enterprise container platform based on kubernetes](/img/ac/773ce8ee7f380df19edf8373250608.jpg)
[cloud native topic -45]:kubesphere cloud Governance - Introduction and overall architecture of enterprise container platform based on kubernetes

What is the use of entering the critical point? How to realize STM32 single chip microcomputer?

PLT in Matplotlib tight_ layout()
随机推荐
Upgrade tidb with tiup
根据csv文件某一列字符串中某个数字排序
Online yaml to CSV tool
704 二分查找
Day29-t77 & t1726-2022-02-13-don't answer by yourself
Analysis of pointer and array written test questions
[research materials] 2021 Research Report on China's smart medical industry - Download attached
TiDB备份与恢复简介
[MySQL] lock
ESP series pin description diagram summary
ESP系列引脚說明圖匯總
升级 TiDB Operator
China vanadium battery Market Research and future prospects report (2022 Edition)
vulnhub hackme: 1
On the day of resignation, jd.com deleted the database and ran away, and the programmer was sentenced
[luatos-air551g] 6.2 repair: restart caused by line drawing
Introduction to number theory (greatest common divisor, prime sieve, inverse element)
LDAP Application Section (4) Jenkins Access
Bottom up - physical layer
[2022 广东省赛M] 拉格朗日插值 (多元函数极值 分治NTT)