当前位置:网站首页>Pytorch框架学习记录1——Dataset类代码实战
Pytorch框架学习记录1——Dataset类代码实战
2022-07-30 03:54:00 【柚子Roo】
Pytorch框架学习记录1——Dataset类代码实战
介绍
torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__和__getitem__这两个方法就可以。- 通过
torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
这两个抽象类中用到的python知识点
能够熟练的使用python语言的技巧,是理解pytorch源码的关键。在torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类中会用到python抽象类的魔法方法,包括__len__(self),getitem(self)和__iter__(self)
__len__(self)定义当被len()函数调用时的行为(返回容器中元素的个数)__getitem__(self)定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。__iter__(self)定义当迭代容器中的元素的行为
数据集下载地址:https://pan.baidu.com/s/1qNCOVz15mCSQEDZZXJAaoQ?pwd=qz2b
提取码:qz2b
1. 导入包
from torch.utils.data import Dataset
from PIL import Image
import os
2. 创建类
创建子类MyData,继承父类Dataset,并对函数进行重写。
class MyData(Dataset):
def __init__(self, root_dir, label_dir):
self.root_dir = root_dir
self.label_dir = label_dir
self.path = os.path.join(self.root_dir, self.label_dir) # 路径拼接
self.img_path = os.listdir(self.path) # 获得文件夹下所有的文件名
def __getitem__(self, idx):
img_name = self.img_path[idx]
img_item_path = os.path.join(self.root_dir, self.img_path, img_name)
img = Image.open(img_item_path)
label = self.label_dir
return img, label
def __len__(self):
return len(self.img_path)
3. 调用
类进行实例化,并进行调用
root_dir = "C:\\Users\\hp\\PycharmProjects\\pythonProject\\Pytorch_Learning\\flower_data\\train"
daisy_label_dir = "daisy"
roses_label_dir = "roses"
daisy_dataset = MyData(root_dir, daisy_label_dir)
roses_dataset = MyData(root_dir, roses_label_dir)
train_dataset = daisy_dataset + roses_dataset
print("daisy:",len(daisy_dataset),"\nroses",len(roses_dataset),"\ndaisy+roses",len(train_dataset))
img, label = train_dataset[0]
img2, label = train_dataset[577]
img2.show()
边栏推荐
猜你喜欢

Process priority nice

Mini Program Graduation Works WeChat Second-hand Trading Mini Program Graduation Design Finished Work (2) Mini Program Function

What is the difference between mission, vision and values?

一起来学习flutter 的布局组件

Nacos 安装与部署

小程序毕设作品之微信二手交易小程序毕业设计成品(8)毕业设计论文模板

Eureka注册中心

spicy(一)基本定义

Introduction to management for technical people 1: What is management

Nacos集群分区
随机推荐
Nacos 安装与部署
redis分布式锁的原子保证
Alibaba search new product data API by keyword
一直空、一直爽,继续抄顶告捷!
论坛管理系统
为什么突然间麒麟 9000 5G 版本,又有库存了?
EasyNVR平台级联到EasyCVR,视频播放一会就无法播放是什么原因?
小程序毕设作品之微信二手交易小程序毕业设计成品(6)开题答辩PPT
精品:淘宝/天猫获取购买到的商品订单详情 API
【Node访问MongoDB数据库】
操作配置:如何在一台服务器中以服务方式运行多个EasyCVR程序?
EasyCVR启动时报错“no such file or directory”,该如何解决?
MySQ deadlock
Boutique: Taobao/Tmall Get Order Details API for Purchased Products
运行时间监控:如何确保网络设备运行时间
Hystrix service circuit breaker
Transformation of traditional projects
Smart answer function, CRMEB knowledge payment system must have!
STM32 SPI+WM8978语音回环
【转】Swift 中的面向协议编程:引言