III. How to Build a Custom Dataset?
2022-07-29 05:22:00 【MY头发乱了】
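Preface
MNIST, the most basic dataset of all, has been played to death by everyone walking the programmer's path, so today I'll show you how to build a custom dataset.
I. Defining the dataset (no preprocessing inside the class)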
Code (example):
```python
import os
from torch.utils.data import Dataset, DataLoader
from PIL import Image


# 1. Create the dataset class by subclassing Dataset from torch.utils.data.
class My_Dataset(Dataset):
    # 2. Walk the directory tree to collect image paths and attach labels.
    def __init__(self, main_dir, data_type, transforms):
        self.dataset = []  # list of [image_path, label] pairs
        self.transforms = transforms
        if data_type == 0:
            data_filename = 'train'
        elif data_type == 1:
            data_filename = 'val'
        else:
            data_filename = 'test'
        split_dir = os.path.join(main_dir, data_filename)
        for cls_filename in os.listdir(split_dir):
            for img_name in os.listdir(os.path.join(split_dir, cls_filename)):
                # The label is the first character of the file name, so file
                # names must start with a single-digit class label.
                self.dataset.append([os.path.join(split_dir, cls_filename, img_name),
                                     int(img_name[0])])

    # 3. Return the number of samples so the DataLoader knows how many indices to iterate over.
    def __len__(self):
        return len(self.dataset)

    # 4. Look up the image path, open the image and apply the preprocessing passed in.
    def __getitem__(self, index):
        img, label = self.dataset[index]
        img_data = Image.open(img)
        img_data = self.transforms(img_data)
        return img_data, label
```
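To see why `__len__` and `__getitem__` are the two methods that matter, here is a simplified, stdlib-only sketch of what a DataLoader does with them: it takes the indices `range(len(dataset))`, optionally shuffles them, calls `dataset[i]` for each one and groups the results into batches. This is not PyTorch's implementation (the real DataLoader also collates samples into tensors and can use worker processes); `ToyDataset` and `simple_loader` are hypothetical names used only for illustration.

```python
import random


class ToyDataset:
    """Hypothetical stand-in for My_Dataset: stores (sample, label) pairs."""

    def __init__(self, samples):
        self.samples = list(samples)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self.samples[index]


def simple_loader(dataset, batch_size, shuffle=False, seed=None):
    """Yield lists of samples, mimicking the index/batch logic of a DataLoader."""
    indices = list(range(len(dataset)))  # __len__ decides which indices exist
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # __getitem__ is called once per index in the batch
        yield [dataset[i] for i in indices[start:start + batch_size]]


data = ToyDataset([("img0", 0), ("img1", 1), ("img2", 0), ("img3", 1), ("img4", 2)])
batches = list(simple_loader(data, batch_size=2))
print(batches)
```

With the real classes, `DataLoader(dataset, batch_size=2, shuffle=True)` does this job for a `My_Dataset` instance `dataset`, and by default keeps the smaller final batch just like the sketch.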
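II. Defining the dataset with data preprocessing
Preprocessing can include rotation, cropping, conversion to tensors, resizing, normalization and so on. The example below only converts images to tensors and normalizes them.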
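`transforms.Compose` chains callables: each transform receives the output of the previous one. The stdlib sketch below imitates that idea on a tiny grayscale "image" stored as a list of rows. `compose`, `center_crop`, `hflip` and `to_unit_range` are hypothetical helpers, not torchvision functions; a real pipeline would use torchvision transforms such as `RandomRotation`, `RandomCrop`, `Resize`, `ToTensor` and `Normalize` instead.

```python
def compose(funcs):
    """Return a function that applies each callable in order, like transforms.Compose."""
    def apply(x):
        for f in funcs:
            x = f(x)
        return x
    return apply


def center_crop(size):
    """Hypothetical crop: keep the central size x size block of a 2D list."""
    def crop(img):
        top = (len(img) - size) // 2
        left = (len(img[0]) - size) // 2
        return [row[left:left + size] for row in img[top:top + size]]
    return crop


def hflip(img):
    """Mirror each row (a deterministic horizontal flip)."""
    return [row[::-1] for row in img]


def to_unit_range(img):
    """Scale 0-255 pixel values to [0, 1], as ToTensor does for 8-bit images."""
    return [[p / 255 for p in row] for row in img]


image = [
    [0, 10, 20, 30],
    [40, 50, 60, 70],
    [80, 90, 100, 110],
    [120, 130, 140, 150],
]
pipeline = compose([center_crop(2), hflip, to_unit_range])
result = pipeline(image)
print(result)
```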
1. The preprocessing part
```python
    # 4. Look up the image path, open the image and preprocess it.
    def __getitem__(self, index):
        img, label = self.dataset[index]
        img_data = self.data_process(Image.open(img))
        return img_data, label

    # 5. Data processing: augmentation, adding noise, etc. go here.
    def data_process(self, x):
        return transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(mean=(0.5,), std=(0.5,))])(x)
```
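The effect of `ToTensor` followed by `Normalize(mean=(0.5,), std=(0.5,))` can be checked by hand: for 8-bit images `ToTensor` divides pixel values by 255, and `Normalize` computes (x - mean) / std per channel, so the range [0, 1] becomes [-1, 1]. The helper below is a plain-Python sketch of that arithmetic for a single channel, not torchvision code; `to_tensor_then_normalize` is a hypothetical name.

```python
def to_tensor_then_normalize(pixel, mean=0.5, std=0.5):
    """Map an 8-bit pixel value the way ToTensor + Normalize would for one channel."""
    x = pixel / 255          # ToTensor: [0, 255] -> [0.0, 1.0]
    return (x - mean) / std  # Normalize: (x - mean) / std


for p in (0, 128, 255):
    print(p, to_tensor_then_normalize(p))
```

Black pixels map to -1.0, white pixels to 1.0, and mid-gray (128) lands just above 0.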
2. The complete dataset definition
Code (example):
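```python
import os
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image


# 1. Create the dataset class by subclassing Dataset from torch.utils.data.
class My_Dataset(Dataset):
    # 2. Walk the directory tree to collect image paths and attach labels.
    def __init__(self, main_dir, data_type):
        self.dataset = []  # list of [image_path, label] pairs
        if data_type == 0:
            data_filename = 'train'
        elif data_type == 1:
            data_filename = 'val'
        else:
            data_filename = 'test'
        split_dir = os.path.join(main_dir, data_filename)
        # Sort the class folders so every split assigns the same index to the same class.
        for cls_idx, cls_filename in enumerate(sorted(os.listdir(split_dir))):
            for img_name in os.listdir(os.path.join(split_dir, cls_filename)):
                # The label is the index of the class folder,
                # not the image's position inside that folder.
                self.dataset.append([os.path.join(split_dir, cls_filename, img_name),
                                     cls_idx])

    # 3. Return the number of samples so the DataLoader knows how many indices to iterate over.
    def __len__(self):
        return len(self.dataset)

    # 4. Look up the image path, open the image and preprocess it.
    def __getitem__(self, index):
        img, label = self.dataset[index]
        img_data = self.data_process(Image.open(img))
        return img_data, label

    # 5. Data processing: augmentation, adding noise, etc. go here.
    def data_process(self, x):
        # ToTensor scales pixels to [0, 1]; Normalize then maps them to [-1, 1].
        return transforms.Compose([transforms.ToTensor(),
                                   transforms.Normalize(mean=(0.5,), std=(0.5,))])(x)
```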
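Using the class folder's index as the label, rather than an image's position within its folder, is easiest to check on a tiny directory tree. The stdlib sketch below assumes the layout `main_dir/<split>/<class_name>/<image_file>` used above, builds the same [path, label] list with sorted class folders, and runs on a throwaway directory created with `tempfile`. `build_index`, `SPLITS`, the class names and the file names are hypothetical.

```python
import os
import tempfile

# Hypothetical mapping mirroring data_type 0 / 1 / anything else in My_Dataset.
SPLITS = {0: "train", 1: "val"}


def build_index(main_dir, data_type):
    """Return [image_path, class_index] pairs for main_dir/<split>/<class>/<file>."""
    split_dir = os.path.join(main_dir, SPLITS.get(data_type, "test"))
    samples = []
    # Sorting makes the class -> index mapping identical across splits and machines.
    for cls_idx, cls_name in enumerate(sorted(os.listdir(split_dir))):
        cls_dir = os.path.join(split_dir, cls_name)
        for img_name in sorted(os.listdir(cls_dir)):
            samples.append([os.path.join(cls_dir, img_name), cls_idx])
    return samples


with tempfile.TemporaryDirectory() as root:
    # Fake dataset: two classes in the training split, empty files stand in for images.
    for cls_name, files in {"cat": ["a.png", "b.png"], "dog": ["c.png"]}.items():
        os.makedirs(os.path.join(root, "train", cls_name))
        for f in files:
            open(os.path.join(root, "train", cls_name, f), "w").close()
    index = build_index(root, data_type=0)
    labels = [(os.path.basename(path), label) for path, label in index]
    print(labels)
```

Both images in `cat` get label 0 and the image in `dog` gets label 1, which is the mapping a classifier expects.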