[PyTorch Study Notes] Datasets and Dataloaders
2022-07-03 14:53:00 【liiiiiiiiiiiiike】
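Why keep Datasets and DataLoaders separate
PyTorch aims to keep dataset code decoupled from model training code for better readability and modularity. It provides two data primitives, torch.utils.data.DataLoader and torch.utils.data.Dataset. Together they let you use both pre-loaded datasets (such as those shipped with torchvision) and your own data.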
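Here is a minimal pure-Python sketch of the separation, with no PyTorch required. A map-style dataset, which is what a Dataset subclass is, only needs `__len__` and `__getitem__`. `SquaresDataset` and `sum_of_labels` are hypothetical names used for illustration. The point is that the consuming code relies only on `len()` and indexing, so it works just as well with a plain list.

```python
class SquaresDataset:
    """Hypothetical map-style dataset: sample i is the pair (i, i * i)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        # Number of samples, used by len(dataset)
        return self.n

    def __getitem__(self, idx):
        # Return one (feature, label) sample by integer index
        if not 0 <= idx < self.n:
            raise IndexError(idx)
        return idx, idx * idx


def sum_of_labels(dataset):
    """'Training-side' code: depends only on len() and indexing."""
    return sum(dataset[i][1] for i in range(len(dataset)))


ds = SquaresDataset(4)
print(len(ds), ds[3])                      # 4 (3, 9)
print(sum_of_labels(ds))                   # 14
print(sum_of_labels([("a", 5), ("b", 7)])) # 12, a plain list works too
```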
Loading a dataset
The following example loads the Fashion-MNIST dataset from torchvision:
```python
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

training_data = datasets.FashionMNIST(
    root="data",          # path where the train/test data is stored
    train=True,           # True for the training split, False for the test split
    download=True,        # download the data into root if it is not available locally
    transform=ToTensor()  # convert each PIL image into a float tensor
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)
```
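The `transform=ToTensor()` argument converts each PIL image (or uint8 NumPy array) with pixel values from 0 to 255 into a float tensor with values in [0.0, 1.0], with the channel dimension first (C×H×W). For the 28×28 grayscale FashionMNIST images, each sample therefore becomes a tensor of shape [1, 28, 28]. The stdlib-only sketch below mimics that conversion on a hypothetical 2×2 grayscale image stored as nested lists. It illustrates the idea only and is not torchvision's implementation.

```python
def to_tensor_like(gray_image):
    """Mimic ToTensor on a grayscale image given as rows of 0-255 ints.

    Returns nested lists shaped [1][H][W] (channel first) with values in [0.0, 1.0].
    """
    return [[[pixel / 255 for pixel in row] for row in gray_image]]


img = [[0, 255], [51, 102]]  # hypothetical 2x2 grayscale image
t = to_tensor_like(img)
print(t)  # [[[0.0, 1.0], [0.2, 0.4]]]
```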
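Iterating and visualizing the dataset
A Dataset can be indexed manually like a list, e.g. training_data[index]. The following code plots nine random training samples in a 3×3 grid: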
```python
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
```
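The loop above relies on two details. First, `torch.randint(high, size=(1,)).item()` draws one integer uniformly from [0, high). Second, `figure.add_subplot(rows, cols, i)` numbers the panels from 1, starting at the top left and going row by row. The stdlib sketch below reproduces both with `random.randrange` and `divmod`. `subplot_position` is a hypothetical helper, not a Matplotlib function.

```python
import random


def subplot_position(i, rows, cols):
    """Map a 1-based subplot index to a 0-based (row, col), row-major like add_subplot."""
    if not 1 <= i <= rows * cols:
        raise ValueError(f"subplot index {i} out of range 1..{rows * cols}")
    return divmod(i - 1, cols)


cols, rows = 3, 3
positions = [subplot_position(i, rows, cols) for i in range(1, cols * rows + 1)]
print(positions[0], positions[3], positions[8])  # (0, 0) (1, 0) (2, 2)

# Same range as torch.randint(len(training_data), size=(1,)).item() for the
# 60,000-image FashionMNIST training set: an int in [0, 60000)
sample_idx = random.randrange(60000)
print(0 <= sample_idx < 60000)  # True
```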

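Creating a custom dataset
A custom Dataset class must implement three functions: __init__, __len__, and __getitem__. In the implementation below, the FashionMNIST images are stored in a directory img_dir and their labels are stored separately in a CSV file, annotations_file. The code is as follows: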
```python
import os
import pandas as pd
from torch.utils.data import Dataset
from torchvision.io import read_image


class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # Initialization
        self.img_labels = pd.read_csv(annotations_file)  # image file names and labels
        self.img_dir = img_dir                            # directory containing the images
        self.transform = transform                        # transform applied to each image
        self.target_transform = target_transform          # transform applied to each label

    def __len__(self):
        # Return the number of samples in the dataset
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Load one sample, apply the transforms, and return the transformed (image, label) pair
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)  # decode the image file into a uint8 tensor
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
```
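The annotations file is expected to hold one row per image. The file name goes in the first column (`iloc[idx, 0]`) and the label in the second (`iloc[idx, 1]`). Note that `pd.read_csv` treats the first line as a header by default, so the file should start with a header row; otherwise the first sample is consumed as column names. The sketch below is a stdlib-only analog of `CustomImageDataset` and assumes a hypothetical CSV with a `file_name,label` header. It returns the image path instead of decoding the image, so it runs without image files or PyTorch. It also passes a hypothetical `one_hot` function as `target_transform`.

```python
import csv
import io
import os


class CsvLabelDataset:
    """Hypothetical stdlib-only analog of CustomImageDataset.

    Returns (image_path, label) instead of a decoded image tensor.
    """

    def __init__(self, csv_text, img_dir, transform=None, target_transform=None):
        reader = csv.reader(io.StringIO(csv_text))
        next(reader)  # skip the header row, which pd.read_csv would use as column names
        self.img_labels = [(name.strip(), int(label)) for name, label in reader]
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        name, label = self.img_labels[idx]
        image = os.path.join(self.img_dir, name)  # stands in for read_image(img_path)
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


def one_hot(label, num_classes=10):
    """Hypothetical target_transform: encode an integer label as a one-hot list."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


annotations = "file_name,label\ntshirt1.jpg,0\nankleboot1.jpg,9\n"
ds = CsvLabelDataset(annotations, "img_dir", target_transform=one_hot)
image, target = ds[1]
print(len(ds), image, target.index(1.0))  # on POSIX: 2 img_dir/ankleboot1.jpg 9
```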
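Preparing data for training with DataLoaders
When training a model, we usually feed samples in mini-batches and reshuffle the data at every epoch to reduce overfitting. When num_workers > 0, we also use Python multiprocessing to speed up data retrieval. DataLoader is an iterable that handles all of this behind a simple API: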
```python
from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
```
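Conceptually, a DataLoader with `shuffle=True` permutes the sample indices at the start of each epoch, slices them into chunks of `batch_size`, and collates each chunk. The sketch below is a hypothetical single-process stand-in; `simple_loader` is not a PyTorch API. It returns Python lists where PyTorch's default collation would stack the samples into tensors. Like DataLoader's default `drop_last=False`, it keeps the final partial batch.

```python
import random


def simple_loader(dataset, batch_size, shuffle=False, rng=random):
    """Hypothetical single-process stand-in for DataLoader: yields (features, labels) lists."""
    indices = list(range(len(dataset)))
    if shuffle:
        rng.shuffle(indices)  # a fresh permutation on every call, i.e. every epoch
    for start in range(0, len(indices), batch_size):
        batch = [dataset[i] for i in indices[start:start + batch_size]]
        features, labels = zip(*batch)  # "collate": list of pairs -> pair of lists
        yield list(features), list(labels)


data = [(f"img{i}", i % 10) for i in range(10)]  # toy (feature, label) pairs
batches = list(simple_loader(data, batch_size=4, shuffle=True, rng=random.Random(0)))
print([len(features) for features, _ in batches])  # [4, 4, 2]

# Like next(iter(train_dataloader)): take only the first batch
first_features, first_labels = next(iter(simple_loader(data, batch_size=4)))
print(first_features)  # ['img0', 'img1', 'img2', 'img3']
```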
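Iterating through the DataLoader
Once the dataset has been loaded into a DataLoader, we can iterate through it as needed. Each iteration below returns a batch of train_features and train_labels containing 64 images and 64 labels. Because shuffle=True, the data is reshuffled after all batches have been iterated over.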
```python
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))  # one batch: 64 images and 64 labels
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()  # first image; squeeze drops the size-1 channel dim: [1, 28, 28] -> [28, 28]
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
```
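With the ToTensor transform, the printed shapes should be torch.Size([64, 1, 28, 28]) for the features and torch.Size([64]) for the labels. Not every batch has 64 samples, however. FashionMNIST has 60,000 training and 10,000 test images, neither of which is a multiple of 64, so the last batch of each epoch is smaller unless `drop_last=True` is passed to DataLoader. The helper below (a hypothetical name) does the arithmetic.

```python
def batch_stats(num_samples, batch_size, drop_last=False):
    """Return (batches per epoch, size of the last batch) for a DataLoader-style split."""
    full, remainder = divmod(num_samples, batch_size)
    if remainder == 0 or drop_last:
        return full, (batch_size if full else 0)
    return full + 1, remainder


print(batch_stats(60000, 64))                  # (938, 32)
print(batch_stats(10000, 64))                  # (157, 16)
print(batch_stats(60000, 64, drop_last=True))  # (937, 64)
```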
