当前位置:网站首页>复杂数据没头绪?
复杂数据没头绪?
2022-06-26 23:31:00 【昇思MindSpore】

在深度学习模型的训练过程中,数据集是起着至关重要作用的。然而,由于任务的复杂性,深度学习模型的输入数据也有着各种各样的形式,深度学习模型搭建的过程中,如果遇到特别复杂的数据,研究者可能要花费大半的时间在数据集的预处理(包括清洗、加载等过程)中。因此,高效的加载数据集,能给研究者构建一套高效的开发流程。使用过PyTorch的读者都知道,PyTorch框架为我们提供了一套极其便利且高效率的自定义数据加载的接口。用户只需要简单的继承torch.utils.data.Dataset并且在get_item函数和__len__函数,再利用Dataloader进行封装,就可以很简单的实现数据集的自动化加载流程(个人认为设置PyTorch在数据层面上做的超级好的一个点)。
● MindSpore数据集加载简介 ●
在MindSpore中,mindspore.dataset里面的函数为我们提供了大量的数据集专有加载算子,这些算子经过优化,拥有较好的数据集加载性能。但是,由于MindSpore本身的数据加载都是在C语言层面完成的,用户很难感知到内部进行的具体操作,特别是针对coco这一类较为复杂的数据集时(就是比较黑洞,很难自己掌握)。由于笔者是一个很喜欢把模型训练的每一步都抓在自己手里的一个人,因此除了cifar10、cifar100、imagefolder等经典的数据(结构)时,尽量都希望自己完成数据集的加载流程,以便更好的了解模型模型和数据集。因此,这篇博客将会主要介绍如何使用MindSpore自定自定义类似PyTorch范式的数据集加载流程。
● mindspore.dataset.Generator
Dataset ●
区别用PyTorch,MindSpore并不能像继承Dataset来完成数据集的构建,但是MindSpore为用户提供了一个类似于DataLoader的数据集封装接口。用户可以通过自定义object对象的数据集对象,然后使用GeneratorDataset进行封装,接下来我将以自定义cifar10和imagenet数据集来简单展示使用GeneratorDataset接口的方法。
● 自定义cifar10数据集 ●
分析格式
在定义数据集之前,我们首先要做的就是数据集的格式分析。在cifar官网中,我们可以得知数据集的基本格式,还可以通过已有的博客,查看读取cifar10的代码样例。如下图所示是cifar-10-batches-py数据集的目录文件,这里我们主要是关注data_batch和test_batch。

加载数据
这里我主要以torchvision中的cifar10数据加载为例,说明构建cifar10数据集的方法。
train_list = [
['data_batch_1', 'c99cafc152244af753f735de768cd75f'],
['data_batch_2', 'd4bba439e000b95fd0a9bffe97cbabec'],
['data_batch_3', '54ebc095f3ab1f0389bbae665268c751'],
['data_batch_4', '634d18415352ddfa80567beed471001a'],
['data_batch_5', '482c414d41f54cd18b22e5b47cb7c3cb'],
]
test_list = [
['test_batch', '40351d587109b95175f43aff81a1287e'],
]
...
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
...
for file_name, checksum in downloaded_list:
file_path = os.path.join(self.root, self.base_folder, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
"""可以很容易理解到,数据集文件里面有一个"data"和一个"label"键,分别拿出来就好"""
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC构建cifar10数据集并且完成预处理
由于cifar10读取进来以后已经是数据形式,因此并不需要想用的图像解码,可以直接使用opencv或者PIL进行处理。这里以cifar10的test数据为例。
import os
import pickle
import numpy as np
import mindspore
from mindspore.dataset import GeneratorDataset
class CIFAR10(object):
train_list = [
'data_batch_1',
'data_batch_2',
'data_batch_3',
'data_batch_4',
'data_batch_5',
]
test_list = [
'test_batch',
]
def __init__(self, root, train, transform=None, target_transform=None):
super(CIFAR10, self).__init__()
self.root = root
self.train = train # training set or test set
if self.train:
downloaded_list = self.train_list
else:
downloaded_list = self.test_list
self.data = []
self.targets = []
self.transform = transform
self.target_transform = target_transform
# now load the picked numpy arrays
for file_name in downloaded_list:
file_path = os.path.join(self.root, file_name)
with open(file_path, 'rb') as f:
entry = pickle.load(f, encoding='latin1')
self.data.append(entry['data'])
if 'labels' in entry:
self.targets.extend(entry['labels'])
else:
self.targets.extend(entry['fine_labels'])
self.data = np.vstack(self.data).reshape(-1, 3, 32, 32)
self.data = self.data.transpose((0, 2, 3, 1)) # convert to HWC
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, target = self.data[index], self.targets[index]
# doing this so that it is consistent with all other datasets
# to return a PIL Image
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
if self.target_transform is not None:
target = self.target_transform(target)
return img, target
def __len__(self):
return len(self.data)
cifar10_test = CIFAR10(root="./cifar10/cifar-10-batches-py", train=False)
cifar10_test = GeneratorDataset(source=cifar10_test, column_names=["image", "label"])
cifar10_test = cifar10_test.batch(128)
for data in cifar10_test.create_dict_iterator():
print(data["image"].shape, data["label"].shape)
(128, 32, 32, 3) (128,)
(128, 32, 32, 3) (128,)
(128, 32, 32, 3) (128,)
(128, 32, 32, 3) (128,)可以从上面的代码看到,虽然语言风格不同,但是MindSpore使用GeneratorDataset依然可以为我们提供一套相对便利的数据集加载方式。对于数据集的预处理的transform代码,研究者可以将代码直接通过transform参数传入get_item函数,十分方便;同时也可以使用MindSpore语言风格,通过dataset自带的map函数,对数据集进行预处理,不过前者的语言风格更加Python,推荐使用。
● 自定义ImageNet ●
分析格式
接下来是介绍ImageNet的数据集自定义过程。其实定义ImageNet数据集加载器是非常方便的,因为图像分类的这类数据集往往是具有树状结构,我们只需要[路径,标签]或者是[图像,标签]的数组对传入到get_item函数中,就可以完成数据集的预处理。
数据加载
这里就简单引用timm中定义folder的部分代码。
def find_images_and_targets(folder, types=IMG_EXTENSIONS, class_to_idx=None, leaf_name_only=True, sort=True):
labels = []
filenames = []
for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
rel_path = os.path.relpath(root, folder) if (root != folder) else ''
label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
for f in files:
base, ext = os.path.splitext(f)
if ext.lower() in types:
filenames.append(os.path.join(root, f))
labels.append(label)
if class_to_idx is None:
# building class index
unique_labels = set(labels)
sorted_labels = list(sorted(unique_labels, key=natural_key))
class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
if sort:
images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
return images_and_targets, class_to_idx可以看到,我们只需要遍历目录,得到images_and_target就好。
Mixup和Cutmix的使用
在ImageNet中,我们常常会使用Mixup和Cutmix等数据增强,但是在对齐进行数据增强的时候,数据集已经是变成[batch_size, channel, height, width]形式出来的,在get_item进行数据预处理的函数是针对单个样本的。在PyTorch中,Mixup和Cutmix是在将数据取出,输入模型之前应用的。在MindSpore中,我们只需要在使用dataset.batch函数之后再对数据集进行预处理。具体的代码可以参考我的博客如何用MindSpore实现自动数据增强
(https://blog.csdn.net/qq_31768873/article/details/121283169),这里展示部分代码。
if (mix_up > 0. or cutmix > 0.) and not is_training:
# if use mixup and not training(False), one hot val data label
one_hot = C.OneHot(num_classes=num_classes)
dataset = dataset.map(input_columns="label", num_parallel_workers=num_parallel_workers,
operations=one_hot)
dataset = dataset.batch(batch_size, drop_remainder=True, num_parallel_workers=num_parallel_workers)
if (mix_up > 0. or cutmix > 0.) and is_training:
mixup_fn = Mixup(
mixup_alpha=mix_up, cutmix_alpha=cutmix, cutmix_minmax=None,
prob=mixup_prob, switch_prob=switch_prob, mode=mixup_mode,
label_smoothing=label_smoothing, num_classes=num_classes)
dataset = dataset.map(operations=mixup_fn, input_columns=["image", "label"],
num_parallel_workers=num_parallel_workers)
return dataset● FAQ ●
自定义数据集的时候,千万要注意要重载len函数,没有这个函数,对象是无法感知数据集大小的。
● 总结 ●
本文介绍了如何使用GeneratorDataset这个接口自定义MindSpore数据集。虽然MindSpore为我们提供了好用的专有数据算子,但是由于数据加载在C语言层面完成,相对于torchvision来说存在着无法感知的缺陷,因此可以尝试使用GeneratorDataset自定义加载,把握每一步细节。(当然,其实也可以去torchvision搬代码拿GeneratorDataset封装就好~)

MindSpore官方资料
GitHub : https://github.com/mindspore-ai/mindspore
Gitee : https : //gitee.com/mindspore/mindspore
官方QQ群 : 486831414
边栏推荐
- 电子协会 C语言 1级 31 、 计算线段长度
- A simple and crude method for exporting R language list to local
- Common techniques of email attachment phishing
- Nacos installation guide
- My advanced learning notes of C language ----- keywords
- 颜色搭配和相关问题
- Operator介绍
- 在手机开户买股票安全吗 网上开户炒股安全吗
- 不同的子序列问题I
- Would you like to buy stocks? Where do you open an account in a securities company? The Commission is lower and safer
猜你喜欢

Unity4.6 Download

12色彩环三原色
![[microservices] Understanding microservices](/img/62/e826e692e7fd6e6e8dab2baa4dd170.png)
[microservices] Understanding microservices

Service discovery, storage engine and static website of go language

开放世界机甲游戏-Phantom Galaxies

您的连接不是私密连接

Can't write to avoid killing and can easily go online CS through defender
![How to download on selenium computer -selenium download and installation graphic tutorial [ultra detailed]](/img/ec/1c324dcf38d07742a139aac2bab02e.png)
How to download on selenium computer -selenium download and installation graphic tutorial [ultra detailed]
![[微服務]認識微服務](/img/62/e826e692e7fd6e6e8dab2baa4dd170.png)
[微服務]認識微服務

PHP代码审计系列(一) 基础:方法、思路、流程
随机推荐
Nacos installation guide
Reading graph augmentations to learn graph representations (lg2ar)
一篇文章带你学会容器逃逸
股票怎样在手机上开户安全吗 网上开户炒股安全吗
Safe and cost-effective payment in Thailand
Open world mecha games phantom Galaxy
低佣金免费开户渠道安全吗?
Can I open an account for stock trading on my mobile phone? Is it safe to open an account for stock trading on the Internet
[microservice]eureka
Nacos安装指南
股票开户有哪些优惠活动?手机开户安全么?
Analysis on the advantages and disadvantages of the best 12 project management systems at home and abroad
Introduction to software engineering -- Chapter 4 -- formal description technology
Crawler and Middleware of go language
go语言的爬虫和中间件
On cap theorem in distributed system development technology
电子协会 C语言 1级 29 、 对齐输出
Electronic Society C language level 1 30, calculation of last term of arithmetic sequence
How to open an account on the mobile phone? Is it safe to open an account online and speculate in stocks
Is the low commission free account opening channel safe?