当前位置:网站首页>Dataset 和 Dataloader数据加载
Dataset 和 Dataloader数据加载
2022-07-25 09:26:00 【zzh1370894823】
初学pytorch, 一直分不清数据是如何加载的,分不清Dataset 和 Dataloader的联系。
utils包含Dataset和Dataloader两个类。自定义数据集需要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__,前者提供数据的大小,后者通过索引获取数据和标签。
__getitem__一次只能获取一个数据,所以需要通过Dataloader来定义一个新的迭代器,实现batch读取。
下面举一个直观的小例子来搞明白是怎么回事!
import torch
from torch.utils import data
import numpy as np
''' 数据集: label:data 0:[1, 2], 1:[3, 4], 0:[2, 1], 1:[3, 4], 2:[4, 5] '''
class TextDataset(data.Dataset): # 继承Dataset
def __init__(self):
self.Data = np.asarray([[1, 2], [3, 4], [2, 1], [3, 4], [4, 5]]) # 一些由2维向量表示的数据集
self.Label = np.asarray([0, 1, 0, 1, 2]) # 数据集对应的标签
def __getitem__(self, item):
text = torch.from_numpy(self.Data[item]) # 把numpy转化为Tensor
label = torch.tensor(self.Label[item])
return text, label
def __len__(self):
return len(self.Data)
# 获取数据集中数据
Test = TextDataset()
print(Test[3]) # 相当于调用getitem(3)
# 输出:
# (tensor([3, 4], dtype=torch.int32), tensor(1, dtype=torch.int32))
以上数据以tuple 返回,每次只返回一个样本,如果希望批量处理batch,需要用到DataLoader
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False)
for i, traindata in enumerate(test_loader):
print("i:",i)
Data, Label = traindata
print("data:", Data) # 其中一个data包含2组数据,一个batch大小
print("label:", Label)
# 输出:
# i: 0
# data: tensor([[1, 2],
# [3, 4]], dtype=torch.int32)
# label: tensor([0, 1], dtype=torch.int32)
# i: 1
# data: tensor([[2, 1],
# [3, 4]], dtype=torch.int32)
# label: tensor([0, 1], dtype=torch.int32)
# i: 2
# data: tensor([[4, 5]], dtype=torch.int32)
# label: tensor([2], dtype=torch.int32)
其中一个data变成原来两组data的组成
相应的label也变成了原来对应的两个label的组成
.
参考于吴茂贵的python深度学习
边栏推荐
- Detailed explanation of MySQL database
- vant问题记录
- 无线振弦采集仪参数配置工具的设置
- vscode插件开发
- [recommended collection] with these learning methods, I joined the world's top 500 - the "fantastic skills and extravagance" in the Internet age
- SSM整合(简单的图书管理系统来整合SSM)
- CCF 201509-4 Expressway
- [necessary for growth] Why do I recommend you to write a blog? May you be what you want to be in years to come.
- 腾讯云之错误[100007] this env is not enable anonymous login
- 拷贝过来老的项目变成web项目
猜你喜欢

Introduction to low power consumption and UPF

小程序分享功能
![[RNN] analyze the RNN from rnn- (simple|lstm) to sequence generation, and then to seq2seq framework (encoder decoder, or seq2seq)](/img/6e/da80133e05b18c87d7167c023b6c93.gif)
[RNN] analyze the RNN from rnn- (simple|lstm) to sequence generation, and then to seq2seq framework (encoder decoder, or seq2seq)

emmet语法速查 syntax基本语法部分

线程池的设计和原理

无线振弦采集仪的使用常见问题

ROS分布式操作--launch文件启动多个机器上的节点

【成长必备】我为什么推荐你写博客?愿你多年以后成为你想成为的样子。

Mlx90640 infrared thermal imager temperature measurement module development notes (4)

~2 CCF 2022-03-1 uninitialized warning
随机推荐
手持振弦VH501TC采集仪传感器的连接与数据读取
Pnpm Brief
Introduction to Verdi Foundation
nodejs版本升级或切换的常用方式
VS无线振弦采集仪蓝牙功能的使用
【建议收藏】靠着这些学习方法,我入职了世界五百强——互联网时代的“奇技淫巧”
工程监测无线中继采集仪和无线网络的优势
About student management system (registration, login, student side)
salt常见问题
CentOs安装redis
Loam transformtoend function integrating IMU details
message from server: “Host ‘xxx.xxx.xxx.xxx‘ is not allowed to connect to this MySQL server“
拷贝过来老的项目变成web项目
How Android uses ADB command to view application local database
Introduction to armv8 general timer
JSP details
LOAM 融合 IMU 细节之 TransformToEnd 函数
ThreadLocal&Fork/Join
CCF 201512-3 drawing
NLM5系列无线振弦传感采集仪的工作模式及休眠模式下状态