当前位置:网站首页>pytorch训练好的模型在加载和保存过程中的问题
pytorch训练好的模型在加载和保存过程中的问题
2022-07-06 08:27:00 【MAR-Sky】
在gpu上训练完成,在cpu上加载
torch.save(model.state_dict(), PATH)# 在gpu上训练后保存
# 在cpu的模型上加载使用
model.load_state_dict(torch.load(PATH, map_location='cpu'))
在cpu上训练完成,在gpu上加载
torch.save(model.state_dict(), PATH)# 在gpu上训练后保存
# 在cpu的模型上加载使用
model.load_state_dict(torch.load(PATH, map_location='cuda:0'))
在使用中需要注意的加载内容
当数据放入GPU,需要训练的模型也要放入GPU
''' data_loader:pytorch中加载数据 '''
for i, sample in enumerate(data_loader): # 对数据进行按批次遍历
image, target = sample # 每一批次加载返回值
if CUDA:
image = image.cuda() # 输入输出传入gpu
target = target.cuda()
# print(target.size)
optimizer.zero_grad() # 优化函数
output = mymodel(image)
mymodel.to(torch.device("cuda"))
多个gpu训练时的加载
参考:https://blog.csdn.net/weixin_43794311/article/details/120940090
import torch.nn as nn
mymodel = nn.DataParallel(mymodel)
pytorch中的nn模块使用nn.DataParallel将模型加载到多个GPU,需要注意,这种加载方式保存的权重参数会比不使用nn.DataParallel加载模型保存的权重参数的关键字前多一个"module."。是否使用nn.DataParallel加载模型,会导致下次再加载模型的时候可能会出现下图的问题,
当权重参数前面多一个“module."时,最简单的方式就是使用nn.DataParallel对模型加载,
边栏推荐
- 2022.02.13 - NC001. Reverse linked list
- leetcode刷题 (5.31) 字符串
- China vanadium battery Market Research and future prospects report (2022 Edition)
- 2022.02.13 - NC003. Design LRU cache structure
- Vocabulary notes for postgraduate entrance examination (3)
- [research materials] 2021 live broadcast annual data report of e-commerce - Download attached
- Upgrade tidb operator
- hcip--mpls
- ESP series pin description diagram summary
- 指针进阶---指针数组,数组指针
猜你喜欢
2022 Inner Mongolia latest water conservancy and hydropower construction safety officer simulation examination questions and answers
2022.02.13 - NC004. Print number of loops
Roguelike游戏成破解重灾区,如何破局?
Bottom up - physical layer
Analysis of pointer and array written test questions
Summary of phased use of sonic one-stop open source distributed cluster cloud real machine test platform
matplotlib. Widgets are easy to use
Beijing invitation media
Résumé des diagrammes de description des broches de la série ESP
CISP-PTE实操练习讲解
随机推荐
Precise query of tree tree
2022.02.13 - NC002. sort
备份与恢复 CR 介绍
logback1.3. X configuration details and Practice
使用 BR 备份 TiDB 集群数据到兼容 S3 的存储
2022.02.13 - 238. Maximum number of "balloons"
【Nvidia开发板】常见问题集 (不定时更新)
sys.argv
Leetcode question brushing (5.28) hash table
Bottom up - physical layer
[MySQL] database stored procedure and storage function clearance tutorial (full version)
China vanadium battery Market Research and future prospects report (2022 Edition)
VMware 虚拟化集群
Huawei cloud OBS file upload and download tool class
Introduction to number theory (greatest common divisor, prime sieve, inverse element)
[MySQL] log
Online yaml to CSV tool
从表中名称映射关系修改视频名称
How to use information mechanism to realize process mutual exclusion, process synchronization and precursor relationship
Pyqt5 development tips - obtain Manhattan distance between coordinates