III. How to Build a Custom Dataset?
2022-07-29 05:22:00 【MY头发乱了】
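Preface
MNIST, the most basic of all datasets, has been played to death by everyone on the programmer's path, so today I'll show you how to build a custom dataset.
I. Defining the dataset, without built-in preprocessing
The code is shown below.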
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# 1. Create the dataset class by subclassing Dataset from torch.utils.data.
class My_Dataset(Dataset):
    # 2. Walk the directory tree and collect each file path together with its label.
    def __init__(self, main_dir, data_type, transforms):
        self.dataset = []  # list of [image path, label] pairs
        self.transforms = transforms
        if data_type == 0:
            data_filename = 'train'
        elif data_type == 1:  # compare integers with ==, not `is`
            data_filename = 'val'
        else:
            data_filename = 'test'
        split_dir = os.path.join(main_dir, data_filename)
        for cls_filename in os.listdir(split_dir):
            for img_data in os.listdir(os.path.join(split_dir, cls_filename)):
                # The label is the first character of the file name,
                # so every file name must start with its digit label.
                self.dataset.append([
                    os.path.join(split_dir, cls_filename, img_data),
                    int(img_data[0])])

    # 3. Return the number of images so the dataset can be iterated over.
    def __len__(self):
        return len(self.dataset)

    # 4. Look up the image path, open the image, and apply the preprocessing.
    def __getitem__(self, index):
        img, label = self.dataset[index]
        img_data = Image.open(img)
        img_data = self.transforms(img_data)
        return img_data, label
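To make the expected folder layout concrete, here is a minimal sketch that uses only the standard library. It assumes the structure `main_dir/<split>/<class folder>/<image file>` and, as in the class above, that each file name starts with its digit label. The helper `scan_split` and the file names are hypothetical and exist only for illustration; the files created are empty placeholders, not real images.

```python
import os
import tempfile

def scan_split(main_dir, split):
    """Collect [path, label] pairs from main_dir/split/<class>/<file>."""
    samples = []
    split_dir = os.path.join(main_dir, split)
    for cls_name in sorted(os.listdir(split_dir)):
        cls_dir = os.path.join(split_dir, cls_name)
        for file_name in sorted(os.listdir(cls_dir)):
            # Assumption: the first character of the file name is the digit label.
            samples.append([os.path.join(cls_dir, file_name), int(file_name[0])])
    return samples

# Build a tiny hypothetical tree: <root>/train/{0,1}/<label>_<name>.png
root = tempfile.mkdtemp()
for cls_name, file_name in [("0", "0_a.png"), ("0", "0_b.png"), ("1", "1_a.png")]:
    os.makedirs(os.path.join(root, "train", cls_name), exist_ok=True)
    open(os.path.join(root, "train", cls_name, file_name), "wb").close()

samples = scan_split(root, "train")
labels = [label for _, label in samples]
print(labels)  # [0, 0, 1]
```

A file whose name does not begin with a digit would raise a `ValueError` in `int(...)`, so this labeling scheme only suits datasets named that way.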
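II. Defining the dataset with preprocessing
Preprocessing can include rotation, cropping, conversion to tensors, enlarging (resizing), normalization, and so on.
1. The preprocessing part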
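As a rough illustration of the tensor conversion and normalization steps, the arithmetic can be sketched in plain Python. It assumes 8-bit grayscale pixel values and the mean 0.5 / std 0.5 settings that the article's code passes to `transforms.Normalize`. The helper `normalize` is hypothetical and is not a torchvision function; it only mirrors the formula `(value - mean) / std`.

```python
def normalize(value, mean=0.5, std=0.5):
    """Apply (value - mean) / std, the per-channel formula behind Normalize."""
    return (value - mean) / std

# For typical 8-bit images, ToTensor rescales pixel values from [0, 255] to [0.0, 1.0].
pixels = [0, 128, 255]
scaled = [p / 255 for p in pixels]

# With mean 0.5 and std 0.5, the range [0.0, 1.0] is mapped onto [-1.0, 1.0].
normalized = [normalize(s) for s in scaled]
print(normalized)
```

Black pixels end up at -1.0, white pixels at 1.0, and mid-gray close to 0.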
    # 4. Look up the image path, open the image, and apply the preprocessing.
    def __getitem__(self, index):
        img, label = self.dataset[index]
        img_data = self.data_process(Image.open(img))
        return img_data, label

    # 5. Data processing: augmentation, added noise, and so on.
    def data_process(self, x):
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,))])(x)
2. The full dataset definition
The code is as follows (example):
import os
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image

# 1. Create the dataset class by subclassing Dataset from torch.utils.data.
class My_Dataset(Dataset):
    # 2. Walk the directory tree and collect each file path together with its label.
    def __init__(self, main_dir, data_type):
        self.dataset = []  # list of [image path, label] pairs
        if data_type == 0:
            data_filename = 'train'
        elif data_type == 1:  # compare integers with ==, not `is`
            data_filename = 'val'
        else:
            data_filename = 'test'
        split_dir = os.path.join(main_dir, data_filename)
        # os.listdir returns entries in arbitrary order, so sort the class
        # folders to keep the class-to-index mapping stable between runs.
        for cls_idx, cls_filename in enumerate(sorted(os.listdir(split_dir))):
            for img_data in os.listdir(os.path.join(split_dir, cls_filename)):
                # Label each image with the index of its class folder.
                self.dataset.append([
                    os.path.join(split_dir, cls_filename, img_data),
                    cls_idx])

    # 3. Return the number of images so the dataset can be iterated over.
    def __len__(self):
        return len(self.dataset)

    # 4. Look up the image path, open the image, and apply the preprocessing.
    def __getitem__(self, index):
        img, label = self.dataset[index]
        img_data = self.data_process(Image.open(img))
        return img_data, label

    # 5. Data processing: augmentation, added noise, and so on.
    def data_process(self, x):
        return transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5,), std=(0.5,))])(x)
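When labels come from the class folders rather than from file names, it helps to make the folder-to-label mapping explicit. The following is a minimal standard-library sketch under the assumption that every class has its own subfolder and that the label should be the index of that folder in sorted order. The helper `build_samples`, the `class_to_idx` dictionary, and the `cat`/`dog` folder names are hypothetical and are not part of the article's class.

```python
import os
import tempfile

def build_samples(split_dir):
    """Return (samples, class_to_idx) for split_dir/<class>/<file>."""
    # os.listdir gives no ordering guarantee, so sort for a reproducible mapping.
    class_names = sorted(
        name for name in os.listdir(split_dir)
        if os.path.isdir(os.path.join(split_dir, name))
    )
    class_to_idx = {name: idx for idx, name in enumerate(class_names)}
    samples = []
    for name in class_names:
        cls_dir = os.path.join(split_dir, name)
        for file_name in sorted(os.listdir(cls_dir)):
            samples.append((os.path.join(cls_dir, file_name), class_to_idx[name]))
    return samples, class_to_idx

# Hypothetical layout: <root>/train/cat/*.png and <root>/train/dog/*.png
root = tempfile.mkdtemp()
for cls_name, file_name in [("dog", "a.png"), ("cat", "b.png"), ("cat", "c.png")]:
    os.makedirs(os.path.join(root, "train", cls_name), exist_ok=True)
    open(os.path.join(root, "train", cls_name, file_name), "wb").close()

samples, class_to_idx = build_samples(os.path.join(root, "train"))
print(class_to_idx)                        # {'cat': 0, 'dog': 1}
print([label for _, label in samples])     # [0, 0, 1]
```

Keeping `class_to_idx` around also makes it easy to translate a predicted index back into a class name later.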