当前位置:网站首页>Data loading and preprocessing
Data loading and preprocessing
2022-07-01 04:46:00 【CyrusMay】
PyTorch (4) — Data preprocessing
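1. Reading data with torch.utils.data.Dataset
- Read the data by subclassing this class.
The files are laid out with one sub-directory per class under the data root (here animal/cat, animal/dog and animal/rabbit):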
import torch
from torch.utils.data import Dataset,DataLoader
import os
import csv
import glob
import random
from PIL import Image
from torchvision import transforms
import visdom
from torchvision.datasets import ImageFolder
class AnimalData(Dataset):
    """Image dataset laid out as ``root/<class_name>/<image file>``.

    Every sub-directory of ``root`` is one class; labels are assigned
    0, 1, 2, ... in sorted directory-name order. The shuffled list of
    (path, label) pairs is cached in ``animal.csv`` and split 60% / 20% / 20%
    into the ``"train"`` / ``"val"`` / ``"test"`` subsets.

    Args:
        root: Directory containing one sub-directory per class.
        resize: Output image size as ``[h, w]``; defaults to ``[28, 28]``.
        mode: Subset to expose: ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        ValueError: If ``mode`` is not one of the three supported values.
    """

    # Glob patterns of the image files collected from each class directory.
    IMAGE_PATTERNS = ("*.png", "*.jpg", "*.jpeg")
    # Per-channel statistics applied by Normalize and undone by de_normalize.
    MEAN = [0.485, 0.456, 0.406]
    STD = [0.229, 0.224, 0.225]

    def __init__(self, root, resize=None, mode="train"):
        super().__init__()
        self.root = root
        self.resize = [28, 28] if resize is None else resize  # [h, w]
        self.mode = mode
        # Validate first so an unknown mode cannot leave the dataset without
        # its images/labels attributes.
        start_frac, end_frac = self._split_fractions(mode)

        # One label per sub-directory, numbered in sorted name order.
        self.class2label = {}
        for name in sorted(os.listdir(self.root)):
            if os.path.isdir(os.path.join(self.root, name)):
                self.class2label[name] = len(self.class2label)
        print(self.class2label)

        # NOTE: the cache file is created in the current working directory and
        # reused as-is when present; delete it after changing root or the data.
        images, labels = self.load_csv("animal.csv")
        start = int(start_frac * len(images))
        end = int(end_frac * len(images))
        self.images = images[start:end]
        self.labels = labels[start:end]

        # Built once instead of on every __getitem__ call.
        self.transform = self._build_transform()

    @staticmethod
    def _split_fractions(mode):
        """Return the (start, end) fractions of the shuffled data used by *mode*."""
        splits = {"train": (0.0, 0.6), "val": (0.6, 0.8), "test": (0.8, 1.0)}
        try:
            return splits[mode]
        except KeyError:
            raise ValueError(
                f"mode must be 'train', 'val' or 'test', got {mode!r}"
            ) from None

    def load_csv(self, file_name):
        """Load (path, label) pairs from *file_name*, creating the file if needed.

        When the file does not exist, every file matching IMAGE_PATTERNS in each
        class directory is collected, shuffled, and written as ``path,label``
        rows. The label is taken from the directory being scanned, so the
        result does not depend on how deep ``root`` is.

        Returns:
            Two parallel lists: image paths and their integer labels.
        """
        if not os.path.exists(file_name):
            rows = []
            for name, label in self.class2label.items():
                class_dir = os.path.join(self.root, name)
                for pattern in self.IMAGE_PATTERNS:
                    # glob.glob returns full paths prefixed with class_dir.
                    for path in glob.glob(os.path.join(class_dir, pattern)):
                        rows.append((path, label))
            # Shuffle once so the train/val/test slices mix all classes.
            random.shuffle(rows)
            # Cache to CSV so the next run skips directory scanning.
            with open(file_name, "w", encoding="utf-8", newline="") as f:
                csv.writer(f).writerows(rows)

        images, labels = [], []
        with open(file_name, "r", encoding="utf-8", newline="") as f:
            for row in csv.reader(f):
                images.append(row[0])
                labels.append(int(row[1]))
        return images, labels

    def __len__(self):
        """Return the number of samples in the selected subset."""
        return len(self.images)

    def de_normalize(self, x_hat):
        """Undo the Normalize step so *x_hat* can be visualised.

        The statistics are shaped (3, 1, 1) and broadcast over the trailing
        dimensions, so a (3, H, W) image or a (B, 3, H, W) batch both work.
        """
        mean = torch.tensor(self.MEAN).view(3, 1, 1)
        std = torch.tensor(self.STD).view(3, 1, 1)
        return x_hat * std + mean

    @staticmethod
    def _load_rgb(path):
        """Open the image at *path* as RGB and close the underlying file."""
        with Image.open(path) as img:
            return img.convert("RGB")

    def _build_transform(self):
        """Compose the pipeline: load, enlarge, (augment), center-crop, normalise."""
        steps = [
            self._load_rgb,
            # Enlarge by 1.25x so the center crop trims the border that a
            # rotation would expose.
            transforms.Resize([int(self.resize[0] * 1.25),
                               int(self.resize[1] * 1.25)]),
        ]
        if self.mode == "train":
            # Random augmentation only for training; val/test stay deterministic.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.MEAN, std=self.STD),
        ]
        return transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return ``(image_tensor, label_tensor)`` for sample *idx*."""
        label = torch.tensor(self.labels[idx])
        image = self.transform(self.images[idx])
        return image, label
if __name__ == "__main__":
    # Build the default ("train") split from ./animal at 128 x 100 (h x w).
    resize = [128, 100]
    db = AnimalData(resize=resize, root="animal")
Output:
{'cat': 0, 'dog': 1, 'rabbit': 2}
2. Loading data in batches with torch.utils.data.DataLoader
if __name__ == "__main__":
    resize = [128, 100]
    db = AnimalData(resize=resize, root="animal")
    sample_iter = iter(db)  # Dataset supports iteration through __getitem__
    vis = visdom.Visdom()

    # Show a single sample, converted back to displayable pixel values.
    first_image, _first_label = next(sample_iter)
    vis.image(db.de_normalize(first_image), win="iter_image",
              opts={"title": "iter_image"})

    # Shuffled batches of 16; num_workers=8 loads them in 8 worker processes.
    loader = DataLoader(dataset=db, batch_size=16, shuffle=True, num_workers=8)
    for batch_images, _batch_labels in loader:
        vis.images(db.de_normalize(batch_images), win="batch_imags", nrow=4,
                   opts={"title": "batch"})
3. Reading data quickly with torchvision.datasets.ImageFolder
if __name__ == "__main__":
    # ImageFolder replaces the custom Dataset above in one step: it scans
    # root/<class>/ for images and applies `transform` to each one.
    # Guarded and self-contained: at module top level this code ran on import
    # and referenced `resize`, which only exists inside a __main__ block.
    resize = [128, 100]  # [h, w]
    tf = transforms.Compose([
        # Enlarge by 1.25x so the center crop trims the rotated border.
        transforms.Resize([int(resize[0] * 1.25), int(resize[1] * 1.25)]),
        transforms.RandomRotation(15),  # data augmentation
        transforms.CenterCrop(resize),  # crop back to [h, w]
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    db = ImageFolder(root="animal", transform=tf)
by CyrusMay, 2022-06-30
How many twists and turns must a life take
Before it reaches the other shore of happiness?
Only then can one live this life without regret:
Ordinary, but not plain.
—————— Mayday (Qingkong Future) ——————
边栏推荐
- Shell之分析服务器日志命令集锦
- Pytorch(二) —— 激活函数、损失函数及其梯度
- 2022 tea master (intermediate) examination question bank and tea master (intermediate) examination questions and analysis
- Codeforces Round #771 (Div. 2) ABCD|E
- 【硬十宝典目录】——转载自“硬件十万个为什么”(持续更新中~~)
- VIM简易使用教程
- 科研狗可能需要的一些工具
- Use and modification of prior network model
- Introduction to JVM stack and heap
- Neural networks - use of maximum pooling
猜你喜欢
C -- array
Extension fragment
AssertionError assert I.ndim == 4 and I.shape[1] == 3
STM32 extended key scan
C - detailed explanation of operators and summary of use cases
Openresty rewrites the location of 302
Why is Internet thinking not suitable for AI products?
技术分享| 融合调度中的广播功能设计
2022 a special equipment related management (elevator) simulation test and a special equipment related management (elevator) certificate examination
Strategic suggestions and future development trend of global and Chinese vibration isolator market investment report 2022 Edition
随机推荐
Neural networks - use sequential to build neural networks
Neural networks - use of maximum pooling
Use and modification of prior network model
Leecode question brushing record 1310 subarray XOR query
Announcement on the list of Guangdong famous high-tech products to be selected in 2021
分布式数据库数据一致性的原理、与技术实现方案
C#读写应用程序配置文件App.exe.config,并在界面上显示
Collect the annual summary of laws, regulations, policies and plans related to trusted computing of large market points (national, ministerial, provincial and municipal)
Section 27 remote access virtual private network workflow and experimental demonstration
Common interview questions ①
Question bank and answers for chemical automation control instrument operation certificate examination in 2022
LeetCode_58(最后一个单词的长度)
[hard ten treasures] - 1 [basic knowledge] classification of power supply
神经网络-使用Sequential搭建神经网络
Tencent has five years of testing experience. It came to the interview to ask for 30K, and saw the so-called software testing ceiling
常用的Transforms中的方法
STM32 光敏电阻传感器&两路AD采集
How to view the changes and opportunities in the construction of smart cities?
I also gave you the MySQL interview questions of Boda factory. If you need to come in and take your own
Pytorch(四) —— 可视化工具 Visdom