Overriding Dataset for a Custom Dataset in PyTorch
2022-07-07 15:40:00 【AI炮灰】
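The previous post produced a data file (motion_data.csv in the code below) with one row per image, holding a filepath column and a label column.
The workflow for building a custom dataset is therefore:
(1) Generate a csv or txt file
See my previous post: 深度学习-制作自己的数据集 (Deep Learning: Making Your Own Dataset), AI炮灰's blog on CSDN
(2) Override Dataset
(3) Build the DataLoader()
(4) Iterate over the data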
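Step (1) is covered in the previous post, so what follows is only a minimal sketch of one way to produce such a file, not the author's script. It assumes one subfolder per class under the dataset directory and uses the sorted folder order as the integer label. The function name `write_image_csv`, the suffix list, and the `run`/`walk` folder names are hypothetical; only the `filepath` and `label` column names come from the code in this post.

```python
import csv
import tempfile
from pathlib import Path

# Assumed set of image extensions; adjust to your data.
IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png", ".bmp"}


def write_image_csv(dataset_dir, csv_path):
    """Write one CSV row per image: its path and an integer class label.

    Assumes dataset_dir contains one subfolder per class (hypothetical layout).
    Returns the class names in label order.
    """
    root = Path(dataset_dir)
    class_names = sorted(p.name for p in root.iterdir() if p.is_dir())
    rows = []
    for label, name in enumerate(class_names):
        for path in sorted((root / name).iterdir()):
            if path.suffix.lower() in IMAGE_SUFFIXES:
                rows.append({"filepath": path.as_posix(), "label": label})
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=["filepath", "label"])
        writer.writeheader()
        writer.writerows(rows)
    return class_names


# Demo on a throwaway folder with empty placeholder files (not real images).
with tempfile.TemporaryDirectory() as tmp:
    for name in ("run", "walk"):
        (Path(tmp) / name).mkdir()
        (Path(tmp) / name / "0.jpg").touch()
    csv_file = Path(tmp) / "motion_data.csv"
    classes = write_image_csv(tmp, csv_file)
    csv_text = csv_file.read_text(encoding="utf-8")
print(classes)
```

Integer labels are used here because the DataLoader can collate numbers into a tensor, whereas string labels would come back as a plain sequence of strings.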
The complete code for steps (2), (3), and (4) is as follows:
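import cv2 as cv
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms


class diff_motion_dataset(Dataset):
    def __init__(self, dataset_dir, csv_path, resize_shape):  # called automatically when the object is created
        # __init__ usually sets up the data transforms and the basic parameters
        self.dataset_dir = dataset_dir
        self.csv_path = csv_path
        self.shape = resize_shape
        # Read the csv file we generated
        self.df = pd.read_csv(self.csv_path, encoding='utf-8')
        self.transformer = transforms.Compose([
            transforms.ToPILImage(),  # cv.imread returns an np.ndarray; Resize expects a PIL image or a tensor
            transforms.Resize(self.shape),
            transforms.ToTensor(),  # convert the PIL image to a float tensor of shape C x H x W
        ])

    def __len__(self):  # return the number of samples
        return len(self.df)

    def __getitem__(self, idx):  # idx is the index of the sample
        # Note: select the column (filepath, label) first and then index it with idx, otherwise an error is raised
        img = cv.imread(self.df['filepath'][idx])  # read the image at row idx of the filepath column
        img = cv.cvtColor(img, cv.COLOR_BGR2RGB)  # OpenCV loads images as BGR; convert to RGB
        x_train = self.transformer(img)  # apply the resize and ToTensor steps (the original never called the transformer)
        y_train = self.df['label'][idx]  # the DataLoader later collates numeric labels into a tensor
        return x_train, y_train  # a single sample, not every row of df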
data_ds = diff_motion_dataset("F:/reshape_images", "F:/reshape_images/motion_data.csv", (256, 256))
# print(len(data_ds))

# Split the data
num_sample = len(data_ds)
train_percent = 0.8
train_num = int(train_percent * num_sample)
test_num = num_sample - train_num
train_ds, test_ds = random_split(data_ds, [train_num, test_num])
# print(len(train_ds))

# 3. Build the DataLoaders: they make the data iterable, group it into batches,
# and support shuffle and num_workers (worker processes for loading)
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=4, shuffle=False)  # evaluation does not need shuffling

# Iterate over the data
for x_train, y_train in train_dl:
    print(x_train.shape)
    print(y_train.shape)
    break
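The split above gives a different partition on every run. If a repeatable split is wanted, `random_split` accepts a seeded `generator`. The sketch below illustrates this on stand-in data; the helper name `split_dataset`, the seed value, and the toy tensors are assumptions for illustration, not part of the original code.

```python
import torch
from torch.utils.data import TensorDataset, random_split


def split_dataset(dataset, train_percent=0.8, seed=0):
    """Split dataset into train/test subsets, reproducibly for a given seed."""
    train_num = int(train_percent * len(dataset))
    test_num = len(dataset) - train_num
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [train_num, test_num], generator=generator)


# Stand-in data: 10 fake "images" of shape 3 x 8 x 8 with integer labels.
toy_ds = TensorDataset(torch.zeros(10, 3, 8, 8), torch.arange(10))
toy_train_ds, toy_test_ds = split_dataset(toy_ds)
print(len(toy_train_ds), len(toy_test_ds))
```

Calling `split_dataset` again with the same seed returns the same indices, which keeps the test set fixed across training runs.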
To train your own model on the custom dataset, call the model you defined as follows:
For creating and loading datasets in other formats, see:
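The original snippet for this step is not available, so the following is only a hedged sketch of what a training call could look like. `TinyNet`, `train_one_epoch`, the two-class setup, and the random stand-in tensors are all hypothetical; in practice the loader would be the `train_dl` built from the custom dataset, and the model would be your own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class TinyNet(nn.Module):
    """Hypothetical placeholder model: one conv layer plus a linear classifier."""

    def __init__(self, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # pool to 1 x 1 so any input size works
        )
        self.classifier = nn.Linear(8, num_classes)

    def forward(self, x):
        return self.classifier(self.features(x).flatten(1))


def train_one_epoch(model, loader, optimizer, loss_fn):
    """Run one pass over loader and return the average loss per sample."""
    model.train()
    total_loss, num_samples = 0.0, 0
    for x_batch, y_batch in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(x_batch), y_batch.long())
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x_batch.size(0)
        num_samples += x_batch.size(0)
    return total_loss / num_samples


# Stand-in loader with random data; replace with train_dl from the custom dataset.
torch.manual_seed(0)
fake_ds = TensorDataset(torch.rand(8, 3, 32, 32), torch.randint(0, 2, (8,)))
fake_dl = DataLoader(fake_ds, batch_size=4, shuffle=True)

model = TinyNet(num_classes=2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
avg_loss = train_one_epoch(model, fake_dl, optimizer, nn.CrossEntropyLoss())
print(f"average loss: {avg_loss:.4f}")
```

The loop only relies on the loader yielding `(image_batch, label_batch)` pairs, which is what a Dataset returning `(x_train, y_train)` provides once it is wrapped in a DataLoader.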