当前位置:网站首页>《nlp入门+实战:第七章:pytorch中数据集加载和自带数据集的使用》
《nlp入门+实战:第七章:pytorch中数据集加载和自带数据集的使用》
2022-07-29 21:08:00 【ZNineSun】
文章目录
上一篇: 《nlp入门+实战:第六章:常见优化器算法的介绍》
本章代码链接:
- https://gitee.com/ninesuntec/nlp-entry-practice/blob/master/code/7.pytorch中数据集加载.py
- https://gitee.com/ninesuntec/nlp-entry-practice/blob/master/code/7.pytorch自带数据集的使用.py
本章数据集地址:
1.模型中使用数据加载器的目的
在前面的线性回归模型中,我们使用的数据很少,所以直接把全部数据放到模型中去使用。
但是在深度学习中,数据量通常是都非常多,非常大的,如此大量的数据,不可能一次性的在模型中进行向前的计算和反向传播,经常我们会对整个数据进行随机的打乱顺序,把数据处理成一个个的batch,同时还会对数据进行预处理。
所以,接下来我们来学习pytorch中的数据加载的方法
2.数据集类
2.1 Dataset基类介绍
在torch中提供了数据集的基类torch.uti1s.data. Dataset,继承这个基类,我们能够非常快速的实现对数据的加载。
torch.utils.data. Dataset的源码如下:
可知:我们需要在自定义的数据集类中继承Dataset类,同时还需要实现两个方法:
- 1._len_方法,能够实现通过全局的len()方法获取其中的元素个数
- 2._getitem_方法,能够通过传入索引的方式获取数据,例如通过dataset[i]获取其中的第i条数据
2.2 数据加载案例
下面通过—个例子来看看如何使用Dataset来加载数据
数据来源:https://archive.ics.uci.edu/ml/datasets/SMS+Spam+Collection/
不过这个数据集好像已经不能用了,但是类似的Kaggle上还有很多,大家也可以尝试下载这一个:
https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset
数据介绍: SMS Spam Collection是用于强扰短信识别的经典数据集,完全来自真实短信内容,包括4831条正常短信和747条骚扰短信。正常短信和骚扰短信保存在一个文本文件中。每行完整记—条短信内容,每行开头通过ham和spam标识正常短信和骚扰短信
数据实例:
实现如下:
import torch
from torch.utils.data import Dataset
data_path = r"C:\Users\NineSun\Desktop\archive\spam.txt" # r表示后面是一个字符串,无需转义
# 完成数据集类
class MyDdataSet(Dataset):
def __init__(self):
self.lines = open(data_path, 'r', encoding='UTF-8').readlines()
def __getitem__(self, index):
# 获取索引对应位置的数据
return self.lines[index]
def __len__(self):
# 返回数据的总数量
return len(self.lines)
if __name__ == '__main__':
my_dataset = MyDdataSet()
print(my_dataset[0])
print(len(my_dataset))

之后对Dataset进行实例化。可以迭代获取其中的数据:
my_dataset = MyDdataSet()
for i in range(len(my_dataset)):
print(i,my_dataset[i])
3.迭代数据集
使用上述的方法能够进行数据的读取,但是其中还有很多内容没有实现:
- 批处理数据(Batching the data)
- 打乱数据(Shuffling the data)
- 使用多线程multiprocessing并行加载数据。
在pytorch中torch.uti1s.data. DataLoader提供了上述的所用方法DataLoader的使用方法示例:
from torch.utils.data import DataLoader
my_dataset = MyDdataSet()
# dataset:实例化之后的数据集;batch_size:batch的大小,一个batch包含10个样本数据;
# shuffle:表示是否打乱数据的顺序;num_workers:表示加载数据时启用线程的数量
data_loader = DataLoader(dataset=my_dataset, batch_size=10, shuffle=True, num_workers=2)
# 遍历,获取每个batch的结果
if __name__ == '__main__':
for i in data_loader:
print(i)
其中参数含义:
- 1.dataset:提前定义的dataset的实例
- 2.batch_size:传入数据的batch的大小,常用128,256等等
- 3.shuffle: bool类型,表示是否在每次获取数据的时候提前打乱数据
- 4.num_workers :加载数据的线程数
注意:
- 1.len(dataset)=数据集的样本数
- 2.1en(dataloader) = math.cei1(样本数/batch_size即向上取整)
print(len(data_loader))
print(len(my_dataset))

4 pytorch自带的数据集
pytorch中自带的数据集由两个上层api提供,分别是torchvision和torchtext;
其中:
- 1.torchvision提供了对图片数据处理相关的api和数据
- 数据位置:torchvision.datasets,例如: torchvision.datasets.MNIST(手写数字图片数据)
- 2.torchtext提供了对文本数据处理相关的API和数据
- 数据位置: torchtext.datas ets ,例如: torchtext.datasets.IMDB(电影评论文本数据)
下面我们以Mnist手写数字为例,来看看pytorch如何加载其中自带的数据集
使用方法和之前一样:
- 1.准备好Dataset实例
- 2.把dataset交给dataloder 打乱顺序,组成batch
在进行下面的内容之前,请大家把torchvision和torchtext安装一下,安装办法也很简单,我使用的是anaconda安装的,安装命令如下:
pip install torchvision
pip install torchtext
记得切换到你所对应的环境
如果上面这种方式下载过慢,可以尝试下面这条指令
pip install --upgrade torch torchvision -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install --upgrade torch torchtext -i https://pypi.tuna.tsinghua.edu.cn/simple
如果没安装成功,请大家自己百度一下吧。
4.1 torchversion.datasets
torchversoin. datasets 中的数据集类(比如torchvision.datasets .MNIST),都是继承自Dataset
意味着:直接对torchvision.datasets.MNIST进行实例化就可以得到Dataset的实例但是MNIST API中的参数需要注意一下:
torchvision.datasets.MNIST(root=’ /fi1es / ’ ,train=True,down1oad=True,transform=)
- 1.root参数表示数据存放的位置
- 2.train: bool类型,表示是使用训练集的数据还是测试集的数据
- 3.download:bool类型,表示是否需要下载数据到root目录
- 4.transform:实现的对图片的处理函数
4.2 MNIST数据集的介绍
数据集的原始地址: http://yann.lecun.com/exdb/mnist/
MNIST是由Yann Lecun等人提供的免费的图像识别的数据集,其中包括60000个训练样本和10000个测试样本,其中图拍了的尺寸已经进行的标准化的处理,都是黑白的图像,大小为28×28
执行代码,下载数据,观察数据类型:
from torchvision.datasets import MNIST
mninst = MNIST(root='./data', train=True, download=True)
print(mninst[0])
运行以后,会在data目录下生成以下数据集

可以看出其中数据集返回了两条数据,可以猜测为图片的数据和目标值
返回值的第0个为Image类型,可以调用show()方法打开,发现为手写数字5
from torchvision.datasets import MNIST
mninst = MNIST(root='./data', train=True, download=True)
print(mninst[0])
img=mninst[0][0]
img.show()

由上可知:返回值为(图片,目标值),这个结果也可以通过观察源码得到
边栏推荐
- Setinel 原理简介
- 错误解决:Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255]
- 什么是数据安全性?
- 惠普服务器硬盘指示灯不亮或显示蓝色
- 怎么实现您的个人知识库?
- 全系都更换带T四缸,安全、舒适一个不落
- Analysis of Crypto in Pi 2
- 《张卫国的夏天》欢乐来袭,黄磊、刘奕君携手演绎“冤种”兄弟
- ALBERT: A Lite BERT for Self-supervised Learning of Language Representations
- sizeof和strlen的区别(strlen和sizeof的用法)
猜你喜欢

GET_ENTITYSET Method Implementation Guide for SAP ABAP OData Service Data Provider Class

怎么实现您的个人知识库?

惠普服务器硬盘指示灯不亮或显示蓝色

全景教程丨VR全景拍摄如何拍摄日出和日落的场景?

ALBERT: A Lite BERT for Self-supervised Learning of Language Representations

数字孪生万物可视 | 联接现实世界与数字空间

全系都更换带T四缸,安全、舒适一个不落

MySQL Data Query - Simple Query

TCP协议详解

优惠券系统设计思想
随机推荐
根据昵称首字母生成头像
【无标题】
1. Promise usage in JS, 2. The concept and usage of closures, 3. The difference between the four methods and areas of object creation, 4. How to declare a class
SQL教程之性能不仅仅是查询
Chrome浏览器打印flash log
[ACTF2020 Freshman Competition]Exec 1
02-SDRAM:自动刷新
分析少年派2中的Crypto
940. Different subsequences II
网安学习-内网渗透2
仿Modbus消息帧进行通信
378. 有序矩阵中第 K 小的元素
MySQL Data Query - Union Query
【ORM框架:Sequelize的查询】
Qualcomm WLAN framework learning (31) -- Power save
普洛斯荣获两项“数据中心绿色等级评估”5A级认证
如何优雅的自定义 ThreadPoolExecutor 线程池
Minimal jvm source code analysis
LeetCode 593 有效的正方形[数学] HERODING的LeetCode之路
容器网络硬核技术内幕 (24) 知微知彰,知柔知刚 (上)