当前位置:网站首页>Pytorch读入据集(典型数据集及自定义数据集两种模式)
Pytorch读入据集(典型数据集及自定义数据集两种模式)
2022-06-24 07:24:00 【Hydrion-Qlz】
数据读入
Pytorch的数据读入是通过DataSet+DataLoader的方式完成的,DataSet定义好数据的格式和数据变换形式,DataLoader通过iterative的方式不断读入批次数据
读入已有的数据集
Pytorch自身支持很多的数据集,可以直接通过对应的函数得到对应的DataSet,然后传入DataLoader中等待处理:
例如读入MNIST数据集
from torchvision import datasets
import torchvision.transforms as transforms
transform = transforms.Compose([transforms.RandomHorizontalFlip,
transforms.RandomCrop,
transforms.ToTensor])
train = datasets.MNIST(root="./datasets",
train=True,
transform=transform,
download=True)
val = datasets.MNIST(root="./datasets",
train=False,
transform=transform,
download=True)
读入自己的数据集
另外也可以通过实现DataSet类来读入自己的数据集,一般来说需要实现三个函数:
__init__:用于向类中传入外部参数,同时定义样本集__getitem__:用于逐个读取样本集合中的元素,可以进行一定的变换,并将返回训练/验证所需的数据__len__:用户返回数据集的样本数
下面的例子是所有的图片存储在一个文件夹下面,同时在一个csv文件中保存有图片名称及其对应的标签
from PIL import Image
class CustomDataSet(Dataset):
def __init__(self, image_path, image_class, transform=None, device="cpu"):
self.image_path = image_path
self.image_class = image_class
self.transform = transform
self.device = device
def show_img(self, index):
plt.subplots(1, 1)
img = Image.open(self.image_path[index])
plt.imshow(img[2])
plt.show()
def __getitem__(self, index):
img = Image.open(self.image_path[index])
if img.mode != 'RGB':
raise ValueError("image:{} isn't RGB mode.".format(self.image_path[index]))
label = np.argmax(self.image_class[index])
label = torch.tensor(label).to(self.device)
if self.transform is not None:
img = self.transform(img)
return img.to(self.device), label
def __len__(self):
return len(self.image_path)
构建好DataSet之后就可以通过DataLoader读取自己的数据了
train_loader = DataLoader(train, batch_size, shuffle=True, drop_last=True)
val_loader = DataLoader(val, batch_size, shuffle=True, drop_last=False)
- shuffle:表示在加载的时候打乱顺序
- drop_last:丢弃掉最后不够一个batch的数据
全部设置完成之后就可以通过下面的函数不断的读取数据集了
for X, y in train_loader:
pass
边栏推荐
- [pytoch basic tutorial 31] youtubednn model analysis
- À propos de ETL il suffit de lire cet article, trois minutes pour vous faire comprendre ce qu'est ETL
- WebRTC系列-网络传输之5选择最优connection切换
- 110. balanced binary tree recursive method
- Centos7 installation of jdk8, mysql5.7 and Navicat connection to virtual machine MySQL and solutions (solutions to MySQL download errors are attached)
- 2138. splitting a string into groups of length k
- OpenCV每日函数 结构分析和形状描述符(7) 寻找多边形(轮廓)/旋转矩形交集
- 1704. judge whether the two halves of a string are similar
- 数据中台:数据中台全栈技术架构解析,附带行业解决方案
- 【牛客】把字符串转换成整数
猜你喜欢

MySQL | 存储《康师傅MySQL从入门到高级》笔记

数据中台:数据中台技术架构详解

从华为WeAutomate数字机器人论坛,看政企领域的“政务新智理”

關於ETL看這篇文章就够了,三分鐘讓你明白什麼是ETL

【E325: ATTENTION】vim编辑时报错

110. 平衡二叉树-递归法

Xiaohei ai4code code baseline nibble 1

Centos7 installation of jdk8, mysql5.7 and Navicat connection to virtual machine MySQL and solutions (solutions to MySQL download errors are attached)

陆奇:我现在最看好这四大技术趋势

基于QingCloud的 “房地一体” 云解决方案
随机推荐
Change SSH port number
[pytorch basic tutorial 30] code analysis of DSSM twin tower model
数据中台:中台架构及概述
[MySQL from introduction to mastery] [advanced part] (I) character set modification and underlying principle
Huawei Router: IPSec Technology
Distributed | how to make "secret calls" with dble
[quantitative investment] discrete Fourier transform to calculate array period
Fast and slow pointer series
开源之夏中选名单已公示,基础软件领域成为今年的热门申请
Analyze the meaning of Internet advertising terms CPM, CPC, CPA, CPS, CPL and CPR
【NOI模拟赛】摆(线性代数,杜教筛)
2020中国全国各省市,三级联动数据,数据机构(数据来自国家统计局官网)
Pymysql inserts data into MySQL and reports an error for no reason
2022-06-23:给定一个非负数组,任意选择数字,使累加和最大且为7的倍数,返回最大累加和。 n比较大,10的5次方。 来自美团。3.26笔试。
Using skills of xargs -- the way to build a dream
[force deduction 10 days SQL introduction] Day3
【MySQL从入门到精通】【高级篇】(一)字符集的修改与底层原理
所说的Get post:请求的区别,你真的知道了吗??????
【量化投资】离散傅里叶变换求数组周期
1844. 将所有数字用字符替换