当前位置:网站首页>Pytorch中自制数据集进行Dataset重写
Pytorch中自制数据集进行Dataset重写
2022-07-07 15:40:00 【AI炮灰】
通过上一篇博文,我们可以获得一下文件的数据如下所示:
所以自制数据集的流程如下:
(1)生成csv或者txt文件
见我上一篇博客:深度学习-制作自己的数据集_AI炮灰的博客-CSDN博客
(2)重写Dataset
(3)生成DataLoader()
(4)迭代数据
(2)(3)(4)步完整代码如下所示;
import pandas as pd
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
import cv2 as cv
class diff_motion_dataset(Dataset):
def __init__(self, dataset_dir, csv_path, resize_shape): # 初始化以后该初始化函数会自行调用
# init方法一般要编写数据的transformer、数据的基本参数
self.dataset_dir = dataset_dir
self.csv_path = csv_path
self.shape = resize_shape
# 读取我们生成的csv文件
self.df = pd.read_csv(self.csv_path, encoding='utf-8')
self.transformer = transforms.Compose([
transforms.Resize(self.shape),
transforms.ToTensor(), # 把PIL核np.array格式的图像转化为Tensor
])
def __len__(self): # 返回数据规模
return len(self.df)
def __getitem__(self, idx): # getitem, idx = index 就是数据样本的下标.特别提醒下面要先把列filename和label取出来再进行idx顺序读取不然就会报错
x_train = cv.imread(self.df['filepath'][idx]) # 读取idx行,filename列的数据(也即是所有图像),然后传入到transformer里面,它会对图像进行resize和toTensor
y_train = self.df['label'][idx] # traindataLoader后面会自动把label转化为tensor
return x_train, y_train # 返回的是单条数据不是df里面的所有数据
data_ds = diff_motion_dataset("F:/reshape_images", "F:/reshape_images/motion_data.csv", (256, 256))
# print(len(data_ds))
# 数据划分
num_sample = len(data_ds)
train_percent = 0.8
train_num = int(train_percent*num_sample)
test_num = num_sample - train_num
train_ds, test_ds = random_split(data_ds, [train_num, test_num])
# print(len(train_ds))
# 3.生成DataLoader().使得数据可以迭代,其次可以将数据分成许多的batch以及shuffer、nun_worker多线程
train_dl = DataLoader(train_ds, batch_size=4, shuffle=True)
test_dl = DataLoader(test_ds, batch_size=4, shuffle=True)
# # 迭代数据
for x_train, y_train in iter(train_dl):
print(x_train.shape)
print(y_train.shape)
break
如果需要自己定义的模型进行自制数据集训练,把定义的模型进行如下调用:
不同格式的是数据集的制作与加载可以见:
边栏推荐
- Ratingbar的功能和用法
- 百度地图自定义样式向右拖拽导致全球地图经度0度无法正常显示
- Rpcms method of obtaining articles under the specified classification
- DNS series (I): why does the updated DNS record not take effect?
- On Apache Doris Fe processing query SQL source code analysis
- L1-027 出租(Lua)
- 《世界粮食安全和营养状况》报告发布:2021年全球饥饿人口增至8.28亿
- calendarview日历视图组件的功能和用法
- MySQL implements the query of merging two fields into one field
- 鲲鹏开发者峰会2022 | 麒麟信安携手鲲鹏共筑计算产业新生态
猜你喜欢
99%的人都不知道|私有化部署还永久免费的即时通讯软件!
Sator launched Web3 game "satorspace" and launched hoobi
【网络攻防原理与技术】第6章:特洛伊木马
Biped robot controlled by Arduino
【OKR目标管理】案例分析
Share the latest high-frequency Android interview questions, and take you to explore the Android event distribution mechanism
alertDialog創建对话框
Linux 安装mysql8.X超详细图文教程
企业即时通讯软件是什么?它有哪些优势呢?
【网络攻防原理与技术】第5章:拒绝服务攻击
随机推荐
[Huang ah code] Why do I suggest you choose go instead of PHP?
【可信计算】第十一次课:TPM密码资源管理(三) NV索引与PCR
textSwitch文本切换器的功能和用法
本周小贴士#135:测试约定而不是实现
Mysql 索引命中级别分析
LeetCode 497(C#)
serachview的功能和用法
[fan Tan] those stories that seem to be thinking of the company but are actually very selfish (I: building wheels)
本周小贴士131:特殊成员函数和`= default`
LeetCode 535(C#)
MySQL implements the query of merging two fields into one field
如何在软件研发阶段落地安全实践
【OKR目标管理】价值分析
Notes on installing MySQL in centos7
Siggraph 2022 best technical paper award comes out! Chen Baoquan team of Peking University was nominated for honorary nomination
Examen des lois et règlements sur la sécurité de l'information
麒麟信安操作系统衍生产品解决方案 | 存储多路径管理系统,有效提高数据传输可靠性
【可信计算】第十次课:TPM密码资源管理(二)
使用popupwindow創建对话框风格的窗口
Nerf: the ultimate replacement for deepfake?