当前位置:网站首页>pytorch加载数据

pytorch加载数据

2022-07-06 03:18:00 七上八下的黑

一、 加载数据

  • 加载数据集里的单张图片
from PIL import Image

img_path="D:\\pycharm\\PycharmProjects\\learn_torch\\hymenoptera_data\\train\\ants\\0013035.jpg"
img = Image.open(img_path)
img.show()
  • 加载数据集列表 
import os

dir_path = "hymenoptera_data/train/ants"
img_path_list = os.listdir(dir_path)

在pycharm中的Python console(控制器)中运行,效果更直观。

二、运用pytorch加载数据

torch.utils 是torch常用的工具箱。

想根据 idx 获取相应的图片的话,先获取这个图片地址的list(获取dataset下的所有data)。

from torch.utils.data import Dataset
  •  定义 MyData 
class MyData(Dataset):

    '''定义全局变量'''
    def __init__(self, root_dir, label_dir):
        self.root_dir = root_dir  # root_dir 是数据集的目录
        self.label_dir = label_dir  # label_dir 是标签的目录
        self.path = os.path.join(self.root_dir, self.label_dir)
        self.img_path = os.listdir(self.path)
 
    '''获取数据'''
    def __getitem__(self, idx):
        img_name = self.img_path[idx]
        img_item_path = os.path.join(self.root_dir, self.label_dir, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    '''数据集的长度'''
    def __len__(self):
        return len(self.img_path)
  •  测试定义的 MyData 
root_dir = "hymenoptera_data/train"
ants_label_dir = "ants"
bees_label_dir = "bees"
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)

方法一: 

可在Python中测试: 

ants_dataset[0]

 方法二:

img, label = ants_dataset[0]
img.show()

 补充:

train_dataset = ants_dataset + bees_dataset  # 整个训练集(蚂蚁数据集和蜜蜂数据集的集合)

在数据集不够时,可以用这种方法补充数据集

原网站

版权声明
本文为[七上八下的黑]所创,转载请带上原文链接,感谢
https://blog.csdn.net/m0_52974810/article/details/125184454