[PyTorch Study Notes] Datasets and DataLoaders
2022-07-03 14:53:00 【liiiiiiiiiiiiike】
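Why DataLoaders are set up separately
PyTorch aims to keep dataset code decoupled from model training code, for better readability and modularity. To that end it provides two classes, torch.utils.data.DataLoader and torch.utils.data.Dataset, which work both with your own datasets and with the pre-loaded datasets that ship with PyTorch libraries such as torchvision.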
Loading a dataset
The following example loads the Fashion-MNIST dataset from torchvision:
import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
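training_data = datasets.FashionMNIST(
    root="data",           # directory where the train/test data is stored
    train=True,            # True selects the training split, False the test split
    download=True,         # download the data into root if it is not already there
    transform=ToTensor(),  # convert each PIL image to a float tensor in [0, 1]
)
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor(),
)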
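Iterating over and visualizing the dataset
A Dataset can be indexed manually like a list, e.g. training_data[index]. The following code visualizes a few randomly chosen samples: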
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()
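Creating a custom dataset
A custom Dataset class must implement three methods: __init__, __len__, and __getitem__. In the example below, the FashionMNIST images are stored in a directory img_dir, and their labels are stored separately in a CSV file: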
import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        # Runs once, when the dataset object is created
        self.img_labels = pd.read_csv(annotations_file)  # table of image file names and labels
        self.img_dir = img_dir  # directory containing the images
        self.transform = transform  # optional transform applied to each image
        self.target_transform = target_transform  # optional transform applied to each label

    def __len__(self):
        # Return the number of samples in the dataset
        return len(self.img_labels)

    def __getitem__(self, idx):
        # Load one sample, apply the transforms, and return the transformed sample
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label
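To try the same three-method pattern without image files or a CSV on disk, here is a minimal sketch that keeps the samples in memory. The class InMemoryImageDataset, the helpers to_float and one_hot, and the random data are all hypothetical stand-ins for illustration; they are not part of the original example or of PyTorch itself.

```python
import torch
from torch.utils.data import Dataset


class InMemoryImageDataset(Dataset):
    """Hypothetical stand-in for CustomImageDataset that holds tensors in memory."""

    def __init__(self, images, labels, transform=None, target_transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        # Called by len(dataset)
        return len(self.labels)

    def __getitem__(self, idx):
        # Called by dataset[idx]
        image = self.images[idx]
        label = int(self.labels[idx])
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label


def to_float(image):
    # Scale uint8 pixels from [0, 255] to floats in [0, 1].
    return image.float() / 255.0


def one_hot(label, num_classes=10):
    # Turn an integer class index into a one-hot float vector.
    return torch.nn.functional.one_hot(torch.tensor(label), num_classes).float()


# Synthetic data (an assumption): 6 fake 1x28x28 grayscale images, labels 0..5.
images = torch.randint(0, 256, (6, 1, 28, 28), dtype=torch.uint8)
labels = torch.arange(6)

dataset = InMemoryImageDataset(images, labels, transform=to_float, target_transform=one_hot)
image, label = dataset[3]
print(len(dataset), image.shape, image.dtype, label)
```

Indexing the dataset triggers __getitem__, so both transforms run lazily, one sample at a time, rather than up front on the whole dataset.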
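Preparing data for training with DataLoaders
When training a model, we usually feed samples in mini-batches, reshuffle the data at every epoch to reduce overfitting, and use Python multiprocessing to speed up data retrieval. A DataLoader is an iterable that handles this for us: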
from torch.utils.data import DataLoader
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
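As a rough illustration of how batch_size splits a dataset, the sketch below uses a small synthetic TensorDataset in place of FashionMNIST, so nothing has to be downloaded. The dataset size of 150 and the variable names are assumptions chosen only to make the final, smaller batch visible.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for FashionMNIST (an assumption): 150 fake images, labels 0..9.
features = torch.rand(150, 1, 28, 28)
targets = torch.randint(0, 10, (150,))
dataset = TensorDataset(features, targets)

# shuffle=True draws a new random order every epoch.
# num_workers=0 (the default) loads data in the main process;
# a larger value loads batches in worker subprocesses.
loader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=0)

batch_sizes = [len(y) for _, y in loader]
print(len(loader), batch_sizes)  # prints: 3 [64, 64, 22]

# drop_last=True discards the final incomplete batch.
strict_loader = DataLoader(dataset, batch_size=64, shuffle=True, drop_last=True)
print(len(strict_loader))  # prints: 2
```

Note that multiprocessing is only used when num_workers is greater than 0; the two DataLoaders created above run with the single-process default.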
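Iterating through the DataLoader
Once the dataset is loaded into a DataLoader, we can iterate over it as needed. Each iteration below returns a batch of 64 images and their labels.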
# Display image and label.
train_features, train_labels = next(iter(train_dataloader))  # one batch of 64 images and labels
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()  # take the first image and squeeze out its size-1 channel dimension
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
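next(iter(...)) fetches only a single batch. A training loop would normally walk the whole DataLoader with a for loop, once per epoch. The sketch below is one way to observe this, under stated assumptions: the dataset is synthetic, each sample's label is simply its own index so the visiting order is visible, and the seeded generator is an optional choice made here for reproducibility.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic data (an assumption): each sample's "label" is its own index.
indices = torch.arange(10)
dataset = TensorDataset(torch.rand(10, 1, 28, 28), indices)

# A seeded generator makes the shuffling reproducible across runs.
generator = torch.Generator().manual_seed(0)
loader = DataLoader(dataset, batch_size=4, shuffle=True, generator=generator)

epoch_orders = []
for epoch in range(2):
    order = []
    for X, y in loader:  # X: [batch, 1, 28, 28], y: [batch]
        order.extend(y.tolist())
    epoch_orders.append(order)
    print(f"epoch {epoch}: {order}")
```

Each epoch visits every sample exactly once, and with shuffle=True the order is redrawn at the start of each pass over the loader.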