当前位置:网站首页>【Pytorch学习笔记】10.如何快速创建一个自己的Dataset数据集对象(继承Dataset类并重写对应方法)
【Pytorch学习笔记】10.如何快速创建一个自己的Dataset数据集对象(继承Dataset类并重写对应方法)
2022-08-05 05:15:00 【takedachia】
我们在做实际项目时,经常会用到自己的数据集,需要将它构造成一个Dataset对象让pytorch能读取使用。
我们之前经常调用 torchvision 库中的数据集对象直接获得常用数据集,如:torchvision.datasets.FashionMNIST()
,这样获得的一个Dataset对象属于 torch.utils.data.Dataset 类。获得Dataset对象后传入DataLoader就可以加载批量数据参与训练了。
如果我们有自己的数据集该怎么定制一个自己的Dataset呢?
继承Dataset类,并重写对应方法创建自己的Dataset
我们看官方文档:
文档中描述了构建一个自己的dataset,需要重写魔法方法__getitem__()
来指定索引访问数据的方法,同时需要重写__len__()
来获取数据集的长度(数量)。
我们直接看个简单的例子,就非常一目了然了:
# 创建数据集对象
class text_dataset(Dataset): #需要继承Dataset类
def __init__(self, words, labels):
self.words = words
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
label = self.labels[idx]
word = self.words[idx]
return word, label
上面我们创建一个数据集对象,对一个单词指定一个情感的标签。
words传入的是各个单词,为一个List。
labels则是各个单词对应的标签,为一个List。
- 在__init__中,我们将传入的序列指定为类的属性
- 在__len__中,我们设定数据集的长度
- 在__getitem__,我们使用参数idx,指定索引访问元素的方法,并指定返回元素
我们有如下数据源:
words = ['Happy', 'Amazing', 'Sad', 'Unhapy', 'Glum']
labels = ['Positive', 'Positive', 'Negative', 'Negative', 'Negative']
dataset_words = text_dataset(words, labels)
dataset_words[0]
# 返回:
# ('Happy', 'Positive')
就可以传入我们创建的dataset,实例化一个新的dataset。可通过下标访问数据。
接着就可以传入一个DataLoader:
train_iter = DataLoader(dataset_words, batch_size=2)
X, y = next(iter(train_iter))
X, y
# 返回:
# (('Happy', 'Amazing'), ('Positive', 'Positive'))
这样,一个简单的Dataset就创建好了。
下面讲一个创建图片数据集的实例。
实例:用自己的图片数据集创建
例子使用的是 动手学深度学习 中的树叶分类项目,地址:https://www.kaggle.com/competitions/classify-leaves
图片数据集长什么样
我们把数据集解压后发现下面一个子文件夹image里存放了共27153张图片,其中标号前18353张图片为训练集,后8800张图片为测试集(测试集没有给label)。
训练集的标签信息在train.csv中,有176类。
我们发现图片的信息和label信息没有直接对应起来,最好是一个图片张量对应一个label类才行。
所以这样的数据集需要处理一下才能读入Dataset中。
但是!
这里我先把这些jpg文件重命名一下,文件名不满5位数的前面填0,因为届时用torchvision.datasets.ImageFolder读取文件是按字符串顺序读取的(ImageFolder的著名坑)。改成如图形式:
文件批量重命名代码:
# 先给文件名称重命名一下,数字不满5位的一律补全0,因为届时用ImageFolder读取是按字符串顺序读取的
# 即 3.jpg → 00003.jpg
import os
path = '../classify-leaves/images'
file_list = os.listdir(path)
for file in file_list:
front, end = file.split('.') # 取得文件名和后缀
front = front.zfill(5) # 文件名补0,5表示补0后名字共5位
new_name = '.'.join([front, end])
# print(new_name)
os.rename(path + '\\' + file, path + '\\' + new_name)
数据预处理
我们先使用torchvision.datasets.ImageFolder把image下的图片读入一个临时的Dataset,data_images
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
train_augs = transforms.Compose([
transforms.Resize(224),
transforms.ToTensor()
])
data_images = ImageFolder(root='../classify-leaves', transform=train_augs)
再读取训练集的标签信息。
train_csv = pd.read_csv('../classify-leaves/train.csv')
print(len(train_csv))
train_csv
我们知道类别信息届时在训练时是需要转成独热编码的,所以需要先把类别信息的label转成类别号。
train_csv.label.unique()可得到所有类别名,其为一个有序的numpy数组,可通过查询的方法来取得索引号,索引号就可以当作类别号。
# 获取某个元素的索引的方法:
# 这个class_to_num可以存起来,之后可作为类别号到类别名称的映射
class_to_num = train_csv.label.unique()
np.where(class_to_num == 'quercus_montana')[0][0] # 取两次[0]取到序号
建立类别号信息:
(上面这个class_to_num可以存起来,之后可作为类别号到类别名称的映射)
train_csv['class_num'] = train_csv['label'].apply(lambda x: np.where(class_to_num == x)[0][0])
train_csv
创建Dataset
# 创建数据集对象 —— leaf
class leaf_dataset(Dataset): # 需要继承Dataset类
def __init__(self, imgs, labels):
self.imgs = imgs
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
label = self.labels[idx]
data = self.imgs[idx][0] # 届时传入一个ImageFolder对象,需要取[0]获取数据,不要标签
return data, label
imgs = data_images
labels = train_csv.class_num
这里将之前用ImageFolder建立的临时Dataset直接作为参数imgs,因为ImageFolder取到图片数据需要再取个0(取1则是label,在这个例子中是“image”),所以在写__getitem__时在取data时后面加个[0]。
下面创建Dataset,传入DataLoader,并显示一下数据:
Leaf_dataset = leaf_dataset(imgs=imgs, labels=labels)
train_iter = DataLoader(dataset=Leaf_dataset, batch_size=256, shuffle=True)
X, y = next(iter(train_iter))
X[0].shape, y[0]
这里,细心的同学可能会问:imgs长度是27153,labels长度是18353:
这样不等长传入一个数据集没问题吗?
事实上一对不等长序列传入Dataset会有本身的问题,但传入DataLoader之后会自动筛掉不等长的部分,最后载入的数据长度依然会是训练集的18353。
还是建议先把Dataset整理一下,可以使用torch.utils.data.Subset方法直接取前18353个元素(也可以在Dataset类内自己修改成想要的样子):
indices = range(len(labels))
Leaf_dataset_tosplit = torch.utils.data.Subset(Leaf_dataset, indices)
最后展示一下图片:
# 展示一下
toshow = [torch.transpose(X[i],0,2) for i in range(16)]
def show_images(imgs, num_rows, num_cols, scale=2):
figsize = (num_cols * scale, num_rows * scale)
_, axes = plt.subplots(num_rows, num_cols, figsize=figsize)
for i in range(num_rows):
for j in range(num_cols):
axes[i][j].imshow(imgs[i * num_cols + j])
axes[i][j].axes.get_xaxis().set_visible(False)
axes[i][j].axes.get_yaxis().set_visible(False)
return axes
show_images(toshow, 2, 8, scale=2)
总结
我们常用继承 torch.utils.data.Dataset 类的方法来构造一个自己的Dataset,同时需要重写以下几个魔法方法:
- 在__init__中,将传入的数据序列指定为类的属性
- 在__len__中,设定数据集的长度
- 在__getitem__,使用参数idx,指定索引访问元素的方法,并指定返回元素
之后就可以传入DataLoader进行读取使用了。
(本文所用代码也可看我的Github)
边栏推荐
猜你喜欢
Using pip to install third-party libraries in Pycharm fails to install: "Non-zero exit code (2)" solution
【过一下 17】pytorch 改写 keras
【过一下14】自习室的一天
将照片形式的纸质公章转化为电子公章(不需要下载ps)
【NFT网站】教你制作开发NFT预售网站官网Mint作品
el-pagination左右箭头替换成文字上一页和下一页
解决:Unknown column ‘id‘ in ‘where clause‘ 问题
软件设计 实验四 桥接模式实验
DOM and its applications
【过一下3】卷积&图像噪音&边缘&纹理
随机推荐
DOM及其应用
学习总结week3_4类与对象
day8字典作业
flink实例开发-batch批处理实例
鼠标放上去变成销售效果
实现跨域的几种方式
Flink 状态与容错 ( state 和 Fault Tolerance)
第四讲 反向传播随笔
基于Flink CDC实现实时数据采集(三)-Function接口实现
RecycleView和ViewPager2
拿出接口数组对象中的所有name值,取出同一个值
学习总结week3_1函数
IDEA 配置连接数据库报错 Server returns invalid timezone. Need to set ‘serverTimezone‘ property.
【NFT开发】设计师无技术基础保姆级开发NFT教程在Opensea上全套开发一个NFT项目+构建Web3网站
Matplotlib(二)—— 子图
Oracle压缩表修改字段的处理方法
解决:Unknown column ‘id‘ in ‘where clause‘ 问题
The difference between the operators and logical operators
SQL(二) —— join窗口函数视图
[Go through 4] 09-10_Classic network analysis