A Detailed Guide to torchvision.datasets.ImageFolder
2022-08-03 00:34:00 【*源仔】
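1. Dataset organization
ImageFolder is a generic image dataset class (the torchvision docs describe it as a generic data loader). It expects the training, validation, or test images to be organized in the following layout, with one subfolder per class: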
root/1/xxx.png
root/1/xxy.png
root/1/xxz.png
. . .
root/2/12.png
. . .
root/3/123.png
. . .
root/4/356.png
. . .
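In the paths above, root is the directory that holds the class folders. Assuming the data folder sits in the same directory as your .py file, root usually looks like ./data/train or ./data/valid.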
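To make the layout concrete, here is a minimal standard-library sketch that builds a throwaway folder tree and derives a class-to-index mapping from sorted subfolder names, which is how ImageFolder is understood to assign indices. The helper name `find_classes_sketch`, the temporary directory, and the empty placeholder files are assumptions made for illustration; no real images are involved.

```python
import os
import tempfile
from pathlib import Path


def find_classes_sketch(root):
    """Hypothetical helper: sorted subfolder names -> indices 0, 1, 2, ..."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    return classes, class_to_idx


# Build a throwaway root/<class>/<file> tree with empty placeholder files.
root = Path(tempfile.mkdtemp()) / "train"
for class_name, file_name in [("1", "xxx.png"), ("2", "12.png"), ("10", "356.png")]:
    (root / class_name).mkdir(parents=True, exist_ok=True)
    (root / class_name / file_name).touch()

classes, class_to_idx = find_classes_sketch(root)
print(classes)
print(class_to_idx)
```

Folder names are sorted as strings, so a folder named 10 sorts before 2. With numeric folder names, zero-padding them (01, 02, 10) keeps the index order aligned with the numeric order.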

2. ImageFolder parameters
dataset = torchvision.datasets.ImageFolder(
    root,
    transform=None,
    target_transform=None,
    loader=torchvision.datasets.folder.default_loader,
    is_valid_file=None)
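Parameters:
- root: the root directory of the images, i.e. the directory that directly contains one subfolder per class.
- transform: a callable applied to each image. It takes the loaded image (a PIL image with the default loader) and returns the transformed version.
- target_transform: a callable applied to the class label. It takes the target and returns its transformed value. If omitted, the target is returned unchanged as the class index 0, 1, 2, ...
- loader: a function that loads an image from its path. The default loader is usually sufficient.
- is_valid_file: a function that takes an image file path and returns whether the file is valid (useful for skipping corrupted files).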
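As an illustration of the kind of callable is_valid_file expects, the sketch below defines a hypothetical predicate, `is_nonempty_image`, that takes a path and returns True only for non-empty files with one of a few chosen extensions. The function name and the extension whitelist are assumptions; a stricter check could also try to open the file with an image library.

```python
import os
import tempfile

# Assumed whitelist for this sketch; adjust to the formats in your dataset.
ALLOWED_EXTENSIONS = (".jpg", ".jpeg", ".png")


def is_nonempty_image(path: str) -> bool:
    """Hypothetical is_valid_file predicate: known extension and size > 0 bytes."""
    has_image_ext = path.lower().endswith(ALLOWED_EXTENSIONS)
    return has_image_ext and os.path.getsize(path) > 0


# Quick check against three temporary files.
tmp_dir = tempfile.mkdtemp()
good = os.path.join(tmp_dir, "a.PNG")
empty = os.path.join(tmp_dir, "b.png")
other = os.path.join(tmp_dir, "notes.txt")
with open(good, "wb") as f:
    f.write(b"\x89PNG placeholder bytes")
open(empty, "wb").close()
with open(other, "w") as f:
    f.write("not an image")

results = [is_nonempty_image(p) for p in (good, empty, other)]
print(results)
```

A predicate like this would be passed as ImageFolder(root, is_valid_file=is_nonempty_image), in which case it decides which files are included in the dataset.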
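The returned dataset has the following three attributes:
- self.classes: a list of the class names.
- self.class_to_idx: a dict mapping each class name to its index, matching the target values returned when no target_transform is applied.
- self.imgs: a list of (image path, class index) tuples.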
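A common follow-up is mapping predicted indices back to class names. The sketch below inverts a class_to_idx dictionary; the literal dictionary is a stand-in for dataset.class_to_idx (using the cat and dog classes from this article's example) so the snippet runs without a dataset on disk, and the predictions are made up.

```python
# Stand-in for dataset.class_to_idx so the snippet runs without images on disk.
class_to_idx = {"cat": 0, "dog": 1}

# Invert the mapping: index -> class name.
idx_to_class = {idx: name for name, idx in class_to_idx.items()}

predicted_indices = [1, 0, 1]  # hypothetical model predictions
predicted_names = [idx_to_class[i] for i in predicted_indices]
print(predicted_names)
```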
3. Example
from torchvision.datasets import ImageFolder
from torchvision import transforms

# Build the preprocessing pipeline
normalize = transforms.Normalize(mean=[.5, .5, .5], std=[.5, .5, .5])
transform = transforms.Compose([
    transforms.RandomCrop(180),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  # convert the image to a Tensor scaled to [0, 1]
    normalize,              # then map each channel from [0, 1] to [-1, 1]
])
dataset = ImageFolder('./data/train', transform=transform)
The resulting dataset behaves like a sequence of (img_data, class_id) pairs: [(img_data, class_id), (img_data, class_id), ...]. Let's print the first element:
print(dataset[0])
Output:
(tensor([[[-0.5137, -0.4667, -0.4902, ..., -0.0980, -0.0980, -0.0902],
[-0.5922, -0.5529, -0.5059, ..., -0.0902, -0.0980, -0.0667],
[-0.5373, -0.5294, -0.4824, ..., -0.0588, -0.0824, -0.0196],
...,
[-0.3098, -0.3882, -0.3725, ..., -0.4353, -0.4510, -0.4196],
[-0.2863, -0.3647, -0.3725, ..., -0.4431, -0.4118, -0.4196],
[-0.3412, -0.3569, -0.3882, ..., -0.4667, -0.4588, -0.4196]],
[[-0.6157, -0.5686, -0.5922, ..., -0.2863, -0.2784, -0.2706],
[-0.6941, -0.6549, -0.6078, ..., -0.2784, -0.2784, -0.2471],
[-0.6392, -0.6314, -0.5843, ..., -0.2471, -0.2706, -0.2078],
...,
[-0.4431, -0.5059, -0.5059, ..., -0.5608, -0.5765, -0.5451],
[-0.4196, -0.4824, -0.5059, ..., -0.5686, -0.5373, -0.5451],
[-0.4745, -0.4902, -0.5294, ..., -0.5922, -0.5843, -0.5451]],
[[-0.6627, -0.6157, -0.6549, ..., -0.5059, -0.5216, -0.5137],
[-0.7412, -0.7020, -0.6706, ..., -0.4980, -0.5216, -0.4902],
[-0.6863, -0.6784, -0.6471, ..., -0.4667, -0.4902, -0.4275],
...,
[-0.6000, -0.6549, -0.6627, ..., -0.6784, -0.6941, -0.6627],
[-0.5765, -0.6314, -0.6471, ..., -0.6863, -0.6549, -0.6627],
[-0.6314, -0.6314, -0.6392, ..., -0.7098, -0.7020, -0.6627]]]), 0)
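The negative values in the printed tensor come from the Normalize step: ToTensor scales 8-bit pixel values to [0, 1], and Normalize with mean 0.5 and std 0.5 then computes (x - 0.5) / 0.5, which lands in [-1, 1]. The plain-Python sketch below reproduces that arithmetic for a single channel value. The function names are illustrative, and the pixel value 62 is inferred from the first printed entry rather than stated in the source.

```python
def to_unit_range(pixel: int) -> float:
    """Mimic the scaling ToTensor applies to an 8-bit value: [0, 255] -> [0, 1]."""
    return pixel / 255.0


def normalize(x: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Per-channel normalization: (x - mean) / std."""
    return (x - mean) / std


def denormalize(y: float, mean: float = 0.5, std: float = 0.5) -> float:
    """Inverse of normalize, handy before displaying an image."""
    return y * std + mean


value = normalize(to_unit_range(62))
print(round(value, 4))
recovered = round(denormalize(value) * 255)
print(recovered)
```

Running the inverse (multiply by std, add mean, scale by 255) is the usual way to turn a normalized tensor back into displayable pixel values.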
Now let's look at the dataset's three attributes:
print(dataset.classes)       # class names, taken from the subfolder names
print(dataset.class_to_idx)  # indices 0, 1, ... assigned to the classes in sorted order
print(dataset.imgs)          # (path, class index) for every image found under root
'''
Output:
['cat', 'dog']
{'cat': 0, 'dog': 1}
[('./data/train\\cat\\1.jpg', 0), ('./data/train\\cat\\2.jpg', 0), ('./data/train\\dog\\1.jpg', 1), ('./data/train\\dog\\2.jpg', 1)]
'''
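Because imgs is just a list of (path, class_index) pairs, per-class image counts can be computed with a Counter, which is a quick way to spot class imbalance. The lists below copy the entries from the output above as stand-ins for dataset.imgs and dataset.classes; the variable names are illustrative.

```python
from collections import Counter

# Stand-ins for dataset.imgs and dataset.classes, copied from the output above.
imgs = [
    ('./data/train\\cat\\1.jpg', 0),
    ('./data/train\\cat\\2.jpg', 0),
    ('./data/train\\dog\\1.jpg', 1),
    ('./data/train\\dog\\2.jpg', 1),
]
classes = ['cat', 'dog']

# Count how many samples carry each class index, then label the counts by name.
counts_by_index = Counter(class_index for _, class_index in imgs)
counts_by_name = {classes[idx]: n for idx, n in sorted(counts_by_index.items())}
print(counts_by_name)
```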
Writing your own datasets.ImageFolder
import numpy as np
from PIL import Image
from torchvision import datasets, transforms
import torchvision.transforms.functional as F

# Note: low_res_augmentation(), called in augment() below, is not defined in this
# article. It must be provided separately and return (image array, resize ratio).


class CustomImageFolderDataset(datasets.ImageFolder):
    def __init__(self,
                 root,
                 transform=None,
                 target_transform=None,
                 loader=datasets.folder.default_loader,
                 is_valid_file=None,
                 low_res_augmentation_prob=0.0,
                 crop_augmentation_prob=0.0,
                 photometric_augmentation_prob=0.0,
                 ):
        super(CustomImageFolderDataset, self).__init__(root,
                                                       transform=transform,
                                                       target_transform=target_transform,
                                                       loader=loader,
                                                       is_valid_file=is_valid_file)
        self.root = root
        # Probability of applying each augmentation to a sample
        self.low_res_augmentation_prob = low_res_augmentation_prob
        self.crop_augmentation_prob = crop_augmentation_prob
        self.photometric_augmentation_prob = photometric_augmentation_prob
        self.random_resized_crop = transforms.RandomResizedCrop(size=(112, 112),
                                                                scale=(0.2, 1.0),
                                                                ratio=(0.75, 1.3333333333333333))
        self.photometric = transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0)
        self.tot_rot_try = 0
        self.rot_success = 0

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if 'WebFace' in self.root:
            # swap rgb to bgr since image is in rgb for webface
            sample = Image.fromarray(np.asarray(sample)[:, :, ::-1])
        # The information score returned by augment() is not used here
        sample, _ = self.augment(sample)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def augment(self, sample):
        # crop with zero padding augmentation
        if np.random.random() < self.crop_augmentation_prob:
            # RandomResizedCrop augmentation: keep the crop in place, zero out the rest
            new = np.zeros_like(np.array(sample))
            # PIL's Image.size is (width, height); this replaces the private
            # F._get_image_size helper, which is not available in every torchvision version
            orig_W, orig_H = sample.size
            i, j, h, w = self.random_resized_crop.get_params(sample,
                                                             self.random_resized_crop.scale,
                                                             self.random_resized_crop.ratio)
            cropped = F.crop(sample, i, j, h, w)
            new[i:i+h, j:j+w, :] = np.array(cropped)
            sample = Image.fromarray(new.astype(np.uint8))
            crop_ratio = min(h, w) / max(orig_H, orig_W)
        else:
            crop_ratio = 1.0
        # low resolution augmentation
        if np.random.random() < self.low_res_augmentation_prob:
            img_np, resize_ratio = low_res_augmentation(np.array(sample))
            sample = Image.fromarray(img_np.astype(np.uint8))
        else:
            resize_ratio = 1
        # photometric augmentation
        if np.random.random() < self.photometric_augmentation_prob:
            fn_idx, brightness_factor, contrast_factor, saturation_factor, hue_factor = \
                self.photometric.get_params(self.photometric.brightness, self.photometric.contrast,
                                            self.photometric.saturation, self.photometric.hue)
            # Apply the adjustments in the random order chosen by ColorJitter
            for fn_id in fn_idx:
                if fn_id == 0 and brightness_factor is not None:
                    sample = F.adjust_brightness(sample, brightness_factor)
                elif fn_id == 1 and contrast_factor is not None:
                    sample = F.adjust_contrast(sample, contrast_factor)
                elif fn_id == 2 and saturation_factor is not None:
                    sample = F.adjust_saturation(sample, saturation_factor)
        information_score = resize_ratio * crop_ratio
        return sample, information_score
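The information_score returned by augment is the product of the two ratios, so it stays at 1 when neither the crop nor the low-resolution step fires and shrinks as more of the image is discarded. The plain-Python sketch below works through that arithmetic with hypothetical numbers; the resize_ratio of 0.5 is an assumed value, since low_res_augmentation is not shown in this article, and the helper function names are illustrative.

```python
def crop_ratio(h: int, w: int, orig_h: int, orig_w: int) -> float:
    """Same formula as in augment(): shortest crop side over longest original side."""
    return min(h, w) / max(orig_h, orig_w)


def information_score(resize_ratio: float, crop_ratio_value: float) -> float:
    """Product of the two ratios, as computed at the end of augment()."""
    return resize_ratio * crop_ratio_value


# Hypothetical numbers: a 112x112 image cropped to a 56x84 region,
# followed by a low-resolution step that reports a ratio of 0.5.
ratio = crop_ratio(h=56, w=84, orig_h=112, orig_w=112)
score = information_score(resize_ratio=0.5, crop_ratio_value=ratio)
print(ratio, score)
```

Note that __getitem__ above discards this score (sample, _ = self.augment(sample)), so it only matters if you change __getitem__ to return or log it.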