PyTorch learning notes 3 - Datasets & DataLoaders & Transforms
Data loading is the first step of any deep-learning pipeline. PyTorch provides two modules, torch.utils.data.Dataset and torch.utils.data.DataLoader, for reading both its built-in datasets and our own datasets.
PyTorch ships with many preloaded datasets, such as FashionMNIST. They are all subclasses of torch.utils.data.Dataset and can be found here: Image Datasets, Text Datasets, and Audio Datasets.
1. Download datasets
Fashion-MNIST is a dataset of clothing images containing 60,000 training samples and 10,000 test samples. Each sample is a 28×28 grayscale image belonging to one of 10 classes. Loading the dataset requires the following parameters:
root is the path where the train/test data is stored; train specifies whether to load the training or the test split; download=True downloads the data from the internet if it is not available at root; transform and target_transform specify the feature and label transformations.
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
    root='data',
    train=True,
    download=True,
    transform=ToTensor()
)
test_data = datasets.FashionMNIST(
    root='data',
    train=False,
    download=True,
    transform=ToTensor()
)
print(len(training_data)) # 60000
2. Iterating and visualizing the dataset
The resulting dataset is a torchvision.datasets.mnist.FashionMNIST object that supports list-style indexing: training_data[index].
print(type(training_data))
print(len(training_data)) # 60000
X, y = training_data[0]
print(f'img[0].shape = {X.shape}')  # torch.Size([1, 28, 28])
print(f'label[0] = {y}')  # 9
The elements of training_data can be iterated in either of the following ways:
for i in range(len(training_data)):
    X, y = training_data[i]

for X, y in training_data:
    pass
Use matplotlib to visualize the dataset:
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}

X, y = training_data[0]
print(f'img[0].shape = {X.shape}')
print(f'label[0] = {y}')

figure = plt.figure()
cols, rows = 4, 4
for i in range(1, cols * rows + 1):
    sample_index = torch.randint(
        low=0,
        high=len(training_data),
        size=(1,)).item()
    img, label = training_data[sample_index]
    # print(img.shape)  # torch.Size([1, 28, 28])
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis('off')
    plt.imshow(img.squeeze(), cmap='gray')
plt.show()
Here, torch.squeeze(input, dim) removes size-1 dimensions from the input tensor. With dim left unspecified, all size-1 dimensions are removed; when dim is given, that dimension is removed only if its size is 1, otherwise the tensor is unchanged. For example:
t = torch.randint(0, 10, size=(1, 28, 1, 28))
print(f't.shape = {t.shape}')
print(f't.squeeze().shape = {t.squeeze().shape}')
print(f't.squeeze(dim=2).shape = {t.squeeze(dim=2).shape}')
print(f'torch.squeeze(input=t, dim=1).shape = {torch.squeeze(input=t, dim=1).shape}')
# t.shape = torch.Size([1, 28, 1, 28])
# t.squeeze().shape = torch.Size([28, 28])
# t.squeeze(dim=2).shape = torch.Size([1, 28, 28])
# torch.squeeze(input=t, dim=1).shape = torch.Size([1, 28, 1, 28])
Because the X read from training_data has shape torch.Size([1, 28, 28]), it cannot be plotted with pyplot directly; after squeeze the shape becomes torch.Size([28, 28]), which pyplot can display. The result is shown below:
[Figure: a 4×4 grid of randomly sampled Fashion-MNIST images, each titled with its class name]
3. Reading your own dataset
To read a custom dataset, you need to implement three functions: __init__, __len__, and __getitem__. In the example below, the FashionMNIST images are stored in the directory img_dir and the labels are stored in the CSV file annotations_file:
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image
class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # initialize the directory containing the images, the annotations file, and both transforms
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # returns the number of samples in our dataset
        return len(self.img_labels)

    def __getitem__(self, idx):
        # loads and returns a sample from the dataset at the given index idx
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
(1) __init__ initializes the image directory, the annotations file, and both transforms. The labels.csv file looks like this:
tshirt1.jpg, 0
tshirt2.jpg, 0
......
ankleboot999.jpg, 9
(2) __len__ returns the number of samples in the dataset.
(3) __getitem__ loads and returns the sample (image and label) at the given index idx.
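To check that the class behaves as expected, here is a minimal usage sketch; the paths data/custom_images and data/custom_images/labels.csv are hypothetical and assume the directory layout described above:
# minimal usage sketch; the paths below are hypothetical
dataset = CustomImageDataset(
    annotations_file='data/custom_images/labels.csv',
    img_dir='data/custom_images')

print(len(dataset))        # __len__: number of rows in labels.csv
image, label = dataset[0]  # __getitem__ with idx=0
print(image.shape, label)  # e.g. torch.Size([1, 28, 28]) and 0 for tshirt1.jpg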
4. DataLoader
DataLoader splits the dataset into minibatches; the shuffle argument controls whether the samples are randomly shuffled.
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
for X, y in test_dataloader:
    print(f'Shape of X [N, C, H, W]: {X.shape}')  # torch.Size([64, 1, 28, 28])
    print(f'Shape of y: {y.shape} {y.dtype}')  # torch.Size([64])
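As a quick sanity check, the number of minibatches per epoch can also be read directly from the DataLoader. A small sketch using the loaders defined above (with batch_size=64 and the default drop_last=False, the last smaller batch is kept, so the counts round up):
# len() of a DataLoader is the number of batches per epoch
print(len(train_dataloader))  # 938 = ceil(60000 / 64)
print(len(test_dataloader))   # 157 = ceil(10000 / 64)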
The object returned by DataLoader is iterable:
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {
train_features.size()}") # torch.Size([64, 1, 28, 28])
print(f"Labels batch shape: {
train_labels.size()}") # torch.Size([64])
# show image[0]
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {
label}")
5. TRANSFORMS
Transforms convert the dataset into a format suitable for training. All TorchVision datasets accept two parameters: transform, which modifies the features, and target_transform, which modifies the labels.
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)
Here, ToTensor converts a PIL image or NumPy array into a FloatTensor and scales the image's pixel values into the range [0., 1.];
target_transform specifies a custom Lambda transform. The code below converts the integer label into a one-hot encoded label (scatter_ sets the position corresponding to label y to 1):
target_transform = Lambda(lambda y: torch.zeros(
    10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
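To make the effect of scatter_ concrete, here is a small sketch of what this transform returns for a single label (using y = 9, the "Ankle Boot" class from the table above):
import torch

y = 9  # "Ankle Boot"
one_hot = torch.zeros(10, dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1)
print(one_hot)  # tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])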
REFERENCES:
1. https://pytorch.org/tutorials/beginner/basics/data_tutorial.html
2. torch.utils.data API
3. TRANSFORMS
For more introductory PyTorch notes, see the PyTorch Learning Notes series.