当前位置:网站首页>torch.load()
torch.load()
2022-07-30 05:38:00 【向大厂出发】
1、torch.load()作用:用来加载torch.save() 保存的模型文件。
torch.load()先在CPU上加载,不会依赖于保存模型的设备。如果加载失败,可能是因为没有包含某些设备,比如你在gpu上训练保存的模型,而在cpu上加载,可能会报错,此时,需要使用map_location来将存储动态重新映射到可选设备上,比如map_location=torch.device('cpu'),意思是映射到cpu上,在cpu上加载模型,无论你这个模型从哪里训练保存的。
一句话:map_location适用于修改模型能在gpu上运行还是cpu上运行。如果map_location是可调用的,那么对于每个带有两个参数的序列化存储,它将被调用一次:storage和location。存储参数将是存储的初始反序列化,驻留在CPU上。每个序列化存储都有一个与之关联的位置标记,它标识保存它的设备,这个标记是传递给map_location的第二个参数。内置的位置标签是“cpu”为cpu张量和“cuda:device_id”(例如:device_id)。“cuda:2”)表示cuda张力。map_location应该返回None或一个存储。如果map_location返回一个存储,它将被用作最终的反序列化对象,已经移动到正确的设备。否则,torch.load()将退回到默认行为,就好像没有指定map_location一样。
如果map_location是一个torch.device对象或一个包含设备标签的字符串,它表示所有张量应该被加载的位置。
2、使用
torch.load(f, map_location=None, pickle_module=<module 'pickle' from '/opt/conda/lib/python3.6/pickle.py'>, **pickle_load_args)1)f – 类文件对象(必须实现read()、readline()、tell()和seek()),或包含文件名的字符串或 os.PathLike 对象
2)map_location – 一个函数、字符串或字典,指定如何重新映射存储位置torch.device
3)pickle_module – 用于解封元数据和对象的模块(必须与pickle_module用于序列化文件的模块匹配)
4)pickle_ load _args –(仅限 Python 3)传递给 and的可选关键字参数,例如 .pickle_module.load()pickle_module.Unpickler()errors=...3、例子
一般情况下,加载模型,主要用于预测新来的一组样本。预测的主要流程包括:输入数据——预处理——加载模型——预测得返回值(类别或者是属于某一类别的概率)
def predict(test_data, model_path, config):
‘’‘
input:
test_data:测试数据
model_path:模型的保存路径 model_path = './save/20201104_204451.ckpt'
output:
score:模型输出属于某一类别的概率
’‘’
data = process_data_for_predict(test_data)#预处理数据,使得数据格式符合模型输入形式
model = torch.load(model_path)#加载模型
score = model(data)#模型预测
return score #返回得分参考:
2) pytorch(一)模型加载函数torch.load()_凝眸伏笔的博客-CSDN博客_python torch.load
边栏推荐
- MySQL 有这一篇就够(呕心狂敲37k字,只为博君一点赞!!!)
- 倒计数(来源:Google Kickstart2020 Round C Problem A)(DAY 88)
- This dependency was not found:
- [Mysql] DATEDIFF function
- 1475. 商品折扣后的最终价格
- Internet (software) company project management software research report
- [Mysql] CONVERT function
- [GO语言基础] 一.为什么我要学习Golang以及GO语言入门普及
- mysql 中 in 的用法
- Arrange numbers (DAY90) dfs
猜你喜欢

【图像检测】基于灰度图像的积累加权边缘检测方法研究附matlab代码

MySql模糊查询大全

cnpm安装步骤

从字节码角度带你彻底理解异常中catch,return和finally,再也不用死记硬背了

Error: listen EADDRINUSE: address already in use 127.0.0.1:3000

从底层结构开始学习FPGA(6)----分布式RAM(DRAM,Distributed RAM)

ClickHouse data insert, update and delete operations SQL

postman 请求 post 调用 传 复合 json数据

手把手教你彻底卸载MySQL

Nacos 原理
随机推荐
The use of Conluce, an online document management system
MySQL索引常见面试题(2022版)
瑞吉外卖项目:新增菜品与菜品分页查询
idea设置自动带参数的方法注释(有效)
Navicat new database
pwn-ROP
图形镜像对称(示意图)
cmd (command line) to operate or connect to the mysql database, and to create databases and tables
解决没有配置本地nacos但是一直发生localhost8848连接异常的问题
429. N 叉树的层序遍历(两种解法)
mysql 时间字段默认设置为当前时间
倒计数(来源:Google Kickstart2020 Round C Problem A)(DAY 88)
MySQL-Explain详解
MYSQL-InnoDB的线程模型
从底层结构开始学习FPGA(6)----分布式RAM(DRAM,Distributed RAM)
PyCharm usage tutorial (more detailed, picture + text)
Error: npm ERR code EPERM
Solve the problem that the local nacos is not configured but the localhost8848 connection exception always occurs
[Mysql] CONVERT函数
cmd(命令行)操作或连接mysql数据库,以及创建数据库与表
https://pytorch.org/docs/stable/generated/torch.load.html?highlight=torch%20load#torch.load