当前位置:网站首页>【Pytorch学习笔记】11.取Dataset的子集、给Dataset打乱顺序的方法(使用Subset、random_split)
【Pytorch学习笔记】11.取Dataset的子集、给Dataset打乱顺序的方法(使用Subset、random_split)
2022-08-05 05:15:00 【takedachia】
(pytorch版本:1.2)
我们在使用Dataset定义好数据集后,在处理数据集时经常会碰到这些问题:如何把Dataset拆分成两个子集(如用于指定训练集和测试集、k折交叉验证等)?如何进行随机拆分?如何打乱一个Dataset内数据的顺序?
Dataset取子集、拆分
使用 torch.utils.data.Subset() 可对数据集取子集。
传入一个Dataset,一个序列切片indices,即可得到一个子集。
1.我们可以传入一个range():
indices = range(18353) # 取标号为第0个到第18352个数据
sub_imgs = torch.utils.data.Subset(imgs, indices)
len(imgs), len(sub_imgs)
2.可以取区间:
indices = range(18353, 27153) # 取标号为第18353个到第27152个数据
sub_imgs = torch.utils.data.Subset(imgs, indices)
len(imgs), len(sub_imgs)
3.可以传入一个List。有List就可以用列表生成式:
indices = [x for x in range(1234)]
sub_imgs = torch.utils.data.Subset(imgs, indices)
len(imgs), len(sub_imgs)
打乱Dataset内数据的顺序
我们可以直接传入一个乱序的index就可以达到数据集乱序的目的:
from torch import randperm
lenth = randperm(len(Leaf_dataset_train)).tolist() # 生成乱序的索引
rand_train = torch.utils.data.Subset(imgs, lenth)
# 显示一下第一张图片、原标号
X = rand_train[0]
plt.imshow(torch.transpose(X[0],0,2)), lenth[0]
我们在打乱顺序后就可以取子集对数据集进行k折交叉验证等行为。
随机拆分Dataset
使用 torch.utils.data.random_split() 可直接对数据集进行拆分,随机分成多份。
可以传入一个List,注意传入的List序列中包含每个子集的大小(数量),且这几个数的和必须等于传入Dataset的长度。
示例:
# 这里Leaf_dataset_train的大小必须等于 17000+1353
train_set, test_set = torch.utils.data.random_split(Leaf_dataset_train, [17000, 1353])
print(len(train_set), len(test_set))
边栏推荐
- flink部署操作-flink on yarn集群安装部署
- day9-字符串作业
- Lecture 3 Gradient Tutorial Gradient Descent and Stochastic Gradient Descent
- In Opencv, imag=cv2.cvtColor(imag,cv2.COLOR_BGR2GRAY) error: error:!_src.empty() in function 'cv::cvtColor'
- 【技能】长期更新
- SSL 证书签发详细攻略
- 【过一下15】学习 lstm的一周
- 第二讲 Linear Model 线性模型
- 【NFT网站】教你制作开发NFT预售网站官网Mint作品
- Oracle压缩表修改字段的处理方法
猜你喜欢
Flink Broadcast 广播变量
Lecture 2 Linear Model Linear Model
Pycharm中使用pip安装第三方库安装失败:“Non-zero exit code (2)“的解决方法
flink中文文档-目录v1.4
DOM and its applications
The software design experiment four bridge model experiment
[Go through 4] 09-10_Classic network analysis
Flink Distributed Cache 分布式缓存
【过一下3】卷积&图像噪音&边缘&纹理
2022年中总结关键词:裁员、年终奖、晋升、涨薪、疫情
随机推荐
SQL(二) —— join窗口函数视图
Flink和Spark中文乱码问题
Flink HA配置
学习总结week3_1函数
解决:Unknown column ‘id‘ in ‘where clause‘ 问题
Flink HA安装配置实战
[Go through 7] Notes from the first section of the fully connected neural network video
Redux
Lecture 5 Using pytorch to implement linear regression
Pandas(五)—— 分类数据、读取数据库
大型Web网站高并发架构方案
拿出接口数组对象中的所有name值,取出同一个值
对数据排序
浅谈Servlet生命周期
如何停止flink job
学习总结week2_3
Mesos learning
npm搭建本地服务器,直接运行build后的目录
【技能】长期更新
NodeJs接收上传文件并自定义保存路径