当前位置:网站首页>【pytorch学习笔记】Datasets and Dataloaders
【pytorch学习笔记】Datasets and Dataloaders
2022-07-03 14:53:00 【liiiiiiiiiiiiike】
为什么要单独设置Dataloaders
pytorch希望将数据集代码与模型训练代码分离,以此获得更好的可读性和模块化。pytorch提供两个接口函数:torch.utils.data.Dataloader 和 torch.utils.data.Dataset.加载自己的数据集和pytorch内置数据集。
加载数据集
下面演示从torchvision加载fashion-mnist数据集的示例,代码如下:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
root = 'data', # root是存储训练/测试数据的路径
train = True, # 指定训练或测试
download=True, # 如果数据不可用,则从网上下载数据到root
transform=ToTensor()
)
test_data = datasets.FashionMNIST(
root = 'data',
train = False,
download=True,
transform=ToTensor()
)
迭代和可视化数据集
我们可以让Datasets像列表一样手动索引,可视化代码如下:
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
sample_idx = torch.randint(len(training_data), size=(1,)).item()
img, label = training_data[sample_idx]
figure.add_subplot(rows, cols, i)
plt.title(labels_map[label])
plt.axis("off")
plt.imshow(img.squeeze(), cmap="gray")
plt.show()
创建自定义数据集
自定义dataset类必须实现三个函数,_init、len__和__getitem。将fashionmnist图像存储在一个目录img_dir,它们标签分别存储在一个CSV文件中。代码如下:
import os
import pandas as pd
from torchvision.io import read_image
class CustomImageDataset(Dataset):
def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
# 初始化
self.img_labels = pd.read_csv(annotations_file) #读取图片标签
self.img_dir = img_dir # 图片路径
self.transform = transform # 图片增强字段
self.target_transform = target_transform # 标签增强字段
def __len__(self):
#返回数据集样本数
return len(self.img_labels) #
def __getitem__(self, idx):
# 读取单个样本并调用图片变换函数,将变换后的函数返回
img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
image = read_image(img_path)
label = self.img_labels.iloc[idx, 1]
if self.transform:
image = self.transform(image)
if self.target_transform:
label = self.target_transform(label)
return image, label
使用DataLoaders为训练准备数据
在训练模型时,通常时mini-batch来训练,在每个epoch重新洗牌以减少模型过度拟合,并使用python multiprocessing加速数据检索。dataloader是一个可迭代对象:
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
遍历DataLoader
已将数据集加载到DataLoader可以根据需要遍历数据集,下面每次迭代都会返回64个图像和标签。
# Display image and label.
train_features, train_labels = next(iter(train_dataloader)) # 64个图像和标签数据
print(f"Feature batch shape: {
train_features.size()}")
print(f"Labels batch shape: {
train_labels.size()}")
img = train_features[0].squeeze()# 选出第一个图像,并将batch维度压缩
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {
label}")
边栏推荐
- 7-10 stack of hats (25 points) (C language solution)
- PS tips - draw green earth with a brush
- Pytoch deep learning and target detection practice notes
- C language to implement a password manager (under update)
- CentOS7部署哨兵Redis(带架构图,清晰易懂)
- Incluxdb2 buckets create database
- Global and Chinese market of optical fiber connectors 2022-2028: Research Report on technology, participants, trends, market size and share
- C language DUP function
- NOI OPENJUDGE 1.5(23)
- Déformation de la chaîne bm83 de niuke (conversion de cas, inversion de chaîne, remplacement de chaîne)
猜你喜欢
How to color ordinary landscape photos, PS tutorial
C string format (decimal point retention / decimal conversion, etc.)
Yolov5 series (I) -- network visualization tool netron
链表有环,快慢指针走3步可以吗
High quality workplace human beings must use software to recommend, and you certainly don't know the last one
C language DUP function
5.4-5.5
Zero copy underlying analysis
Adobe Premiere Pro 15.4 has been released. It natively supports Apple M1 and adds the function of speech to text
[ue4] material and shader permutation
随机推荐
[ue4] geometry drawing pipeline
[ue4] HISM large scale vegetation rendering solution
[opengl] pre bake using computational shaders
Global and Chinese market of air cargo logistics 2022-2028: Research Report on technology, participants, trends, market size and share
Besides lying flat, what else can a 27 year old do in life?
Rasterization: a practical implementation (2)
Global and Chinese markets for sterile packaging 2022-2028: Research Report on technology, participants, trends, market size and share
TPS61170QDRVRQ1
【Transform】【NLP】首次提出Transformer,Google Brain团队2017年论文《Attention is all you need》
Zzuli:1049 square sum and cubic sum
C string format (decimal point retention / decimal conversion, etc.)
Zzuli:1047 logarithmic table
Zzuli:1046 product of odd numbers
[ue4] Niagara's indirect draw
Yolov5 advanced 8 format conversion between high and low versions
PS tips - draw green earth with a brush
C language memory function
How can entrepreneurial teams implement agile testing to improve quality and efficiency? Voice network developer entrepreneurship lecture Vol.03
Yolov5系列(一)——網絡可視化工具netron
Zzuli:1053 sine function