当前位置:网站首页>Dataset 和 Dataloader数据加载
Dataset 和 Dataloader数据加载
2022-07-25 09:26:00 【zzh1370894823】
初学pytorch, 一直分不清数据是如何加载的,分不清Dataset 和 Dataloader的联系。
utils包含Dataset和Dataloader两个类。自定义数据集需要继承这个类,并实现两个函数,一个是__len__,另一个是__getitem__,前者提供数据的大小,后者通过索引获取数据和标签。
__getitem__一次只能获取一个数据,所以需要通过Dataloader来定义一个新的迭代器,实现batch读取。
下面举一个直观的小例子来搞明白是怎么回事!
import torch
from torch.utils import data
import numpy as np
''' 数据集: label:data 0:[1, 2], 1:[3, 4], 0:[2, 1], 1:[3, 4], 2:[4, 5] '''
class TextDataset(data.Dataset): # 继承Dataset
def __init__(self):
self.Data = np.asarray([[1, 2], [3, 4], [2, 1], [3, 4], [4, 5]]) # 一些由2维向量表示的数据集
self.Label = np.asarray([0, 1, 0, 1, 2]) # 数据集对应的标签
def __getitem__(self, item):
text = torch.from_numpy(self.Data[item]) # 把numpy转化为Tensor
label = torch.tensor(self.Label[item])
return text, label
def __len__(self):
return len(self.Data)
# 获取数据集中数据
Test = TextDataset()
print(Test[3]) # 相当于调用getitem(3)
# 输出:
# (tensor([3, 4], dtype=torch.int32), tensor(1, dtype=torch.int32))
以上数据以tuple 返回,每次只返回一个样本,如果希望批量处理batch,需要用到DataLoader
test_loader = data.DataLoader(Test,batch_size=2,shuffle=False)
for i, traindata in enumerate(test_loader):
print("i:",i)
Data, Label = traindata
print("data:", Data) # 其中一个data包含2组数据,一个batch大小
print("label:", Label)
# 输出:
# i: 0
# data: tensor([[1, 2],
# [3, 4]], dtype=torch.int32)
# label: tensor([0, 1], dtype=torch.int32)
# i: 1
# data: tensor([[2, 1],
# [3, 4]], dtype=torch.int32)
# label: tensor([0, 1], dtype=torch.int32)
# i: 2
# data: tensor([[4, 5]], dtype=torch.int32)
# label: tensor([2], dtype=torch.int32)
其中一个data变成原来两组data的组成
相应的label也变成了原来对应的两个label的组成
.
参考于吴茂贵的python深度学习
边栏推荐
- 手持振弦VH501TC采集仪传感器的连接与数据读取
- Download and installation of QT 6.2
- Introduction to armv8 general timer
- See how a junior student of double non-2 (0 Internship) can get an offer from Alibaba and Tencent
- Fundamentals of C language
- MVC三层架构理解
- 概率论与数理统计 4 Continuous Random Variables and Probability Distributions(连续随机变量与概率分布)(上篇)
- js数字千位分割的常用方法
- Excel导入导出源码分析
- JDBC操作数据库详解
猜你喜欢

js利用requestAnimationFrame实时检测当前动画的FPS帧率

rospy Odometry天坑小计

CCF 201512-4 delivery

【建议收藏】靠着这些学习方法,我入职了世界五百强——互联网时代的“奇技淫巧”

Probabilistic robot learning notes Chapter 2

看一个双非二本(0实习)大三学生如何拿到阿里、腾讯的offer

CCF 201604-2 Tetris

message from server: “Host ‘xxx.xxx.xxx.xxx‘ is not allowed to connect to this MySQL server“

Swift creates weather app

MLX90640 红外热成像仪测温模块开发说明
随机推荐
ADC introduction
Introduction to low power consumption and UPF
工程监测无线中继采集仪和无线网络的优势
[recommended collection] with these learning methods, I joined the world's top 500 - the "fantastic skills and extravagance" in the Internet age
Reflection 反射
JS uses requestanimationframe to detect the FPS frame rate of the current animation in real time
CCF 201512-3 drawing
Fundamentals of C language
Common methods of JS digital thousand bit segmentation
Filter过滤器详解(监听器以及它们的应用)
[Android studio] batch data import to Android local database
入住阿里云MQTT物联网平台
JDBC操作数据库详解
[nearly 10000 words dry goods] don't let your resume don't match your talent -- teach you to make the most suitable resume by hand
Introduction to arm GIC
OC -- Inheritance and polymorphic and pointer
MLX90640 红外热成像传感器测温模块开发笔记(二)
C函数不加括号的教训
[deployment of deep learning model] deploy the deep learning model using tensorflow serving + tornado
ThreadLocal&Fork/Join