PyTorch Study Notes 08: Loading Datasets
2022-07-31 05:16:00 【qq_50749521】

In the previous post on the diabetes dataset, we fed the entire dataset as input in every training step. This time we consider mini-batch input.
Three concepts:
epoch: one complete pass over all the training samples.
Batch-Size: the number of samples in each mini-batch.
iteration: one pass over a single mini-batch.
For example, suppose a dataset has 200 samples and we split it into 40 blocks of 5 samples each. Then batch = 40 and batch_size = 5. During training we process one block at a time: running the 5 samples of one block through the model is 1 iteration, and going through all 40 blocks, i.e. training on all 200 samples once, is 1 epoch.
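The same bookkeeping in code (a trivial sketch using the numbers from the example above):

num_samples = 200                                  # total training samples
batch_size = 5                                     # samples per mini-batch
iterations_per_epoch = num_samples // batch_size   # 40 iterations = 1 epoch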
DataLoader: a utility for loading datasets.
What does it do for us? For mini-batch training, we can shuffle the dataset to increase the randomness of training. Feed any dataset that supports indexing and reports its length to a DataLoader, and it will automatically shuffle it and generate mini-batches:
dataset -> shuffle -> batch -> loader
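To see the shuffle-and-batch behaviour in isolation, here is a minimal, self-contained sketch using toy tensors (not the diabetes data) and PyTorch's built-in TensorDataset:

import torch
from torch.utils.data import TensorDataset, DataLoader

x = torch.randn(10, 3)                    # 10 samples, 3 features each
y = torch.randint(0, 2, (10, 1)).float()  # one binary label per sample
loader = DataLoader(TensorDataset(x, y), batch_size=4, shuffle=True)

for inputs, labels in loader:
    # prints (4, 3) twice, then (2, 3): the last mini-batch is smaller
    print(inputs.shape, labels.shape)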

How do you define your own Dataset?
A conceptual skeleton:
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
    def __init__(self):
        pass

    def __getitem__(self, index):
        pass

    def __len__(self):
        pass

dataset = DiabetesDataset()
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=2)
PyTorch provides a Dataset class. It is abstract: it cannot be instantiated, but it can be subclassed.
- DiabetesDataset above is our own subclass of Dataset. __getitem__ and __len__ are magic methods; they return a single sample and the length of the dataset, respectively.
- After instantiating DiabetesDataset, the DataLoader builds mini-batches automatically. Here it is initialized with batch_size, shuffle, and a worker count.
batch_size = 32 sets the number of samples per mini-batch, shuffle = True shuffles the dataset, and num_workers = 2 means two worker processes (not threads) read the data in parallel when the mini-batches are assembled. With more CPU cores you can set this higher.
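One caveat: with num_workers > 0 the DataLoader spawns worker processes, and on platforms that start processes by spawning (e.g. Windows) the loading loop must sit under an if __name__ == '__main__': guard or the workers fail to start. This is why the complete script below wraps the training loop in that guard. A minimal sketch of the required structure:

# Guard needed whenever num_workers > 0 on spawn-based platforms such as Windows
if __name__ == '__main__':
    for inputs, labels in train_loader:
        pass  # training step goes here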
With that, we have the train_loader we wanted and can start training:
for epoch in range(100):
    for index, data in enumerate(train_loader, 0):
        # index is the mini-batch index (there are total_samples / batch_size per epoch);
        # data is the (inputs, labels) pair of tensors
        ...
The complete code for mini-batch training on the diabetes data:
import torch
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        # Read the whole CSV into memory at once
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # number of samples
        self.x_data = torch.from_numpy(xy[:, :-1])   # all columns but the last: features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D: labels

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]  # returns a (features, label) tuple

    def __len__(self):
        return self.len

dataset = DiabetesDataset(r'F:\ASR-source\Dataset\diabetes.csv.gz')
train_loader = DataLoader(dataset=dataset,
                          batch_size=32,
                          shuffle=True,
                          num_workers=0)
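As a quick sanity check (a sketch; the shapes assume the CSV has 8 feature columns plus 1 label column, matching the Linear(8, 1) model below), you can index the dataset directly:

x0, y0 = dataset[0]        # calls __getitem__(0)
print(x0.shape, y0.shape)  # expected: torch.Size([8]) torch.Size([1])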
batch_size = 32
batch = np.round(len(dataset) / batch_size)  # number of mini-batches per epoch
print(batch)  # 24.0
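Equivalently, the loader already knows the batch count: len(train_loader) returns the number of mini-batches per epoch (with the default drop_last=False, a trailing partial batch counts as one), so it should print 24 here as well:

print(len(train_loader))  # expected: 24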
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 1)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        return x

mymodel = Model()
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.SGD(mymodel.parameters(), lr=0.01)
epoch_list = []
loss_list = []
sum_loss = 0

if __name__ == '__main__':
    for epoch in range(100):
        # train_loader yields the shuffled mini-batches with their labels
        for index, data in enumerate(train_loader, 0):
            inputs, labels = data  # both are tensors
            y_pred = mymodel(inputs)
            loss = criterion(y_pred, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            sum_loss += loss.item()
            print('epoch = ', epoch + 1, 'index = ', index + 1, 'loss = ', loss.item())
        epoch_list.append(epoch)
        loss_list.append(sum_loss / batch)  # average loss over this epoch's batches
        print(sum_loss / batch)
        sum_loss = 0
epoch = 1 index = 1 loss = 0.6523504257202148
epoch = 1 index = 2 loss = 0.6662447452545166
epoch = 1 index = 3 loss = 0.6510850191116333
epoch = 1 index = 4 loss = 0.622829794883728
epoch = 1 index = 5 loss = 0.6272122263908386
epoch = 1 index = 6 loss = 0.5990191102027893
epoch = 1 index = 7 loss = 0.6213780045509338
epoch = 1 index = 8 loss = 0.6761874556541443
epoch = 1 index = 9 loss = 0.6133689880371094
epoch = 1 index = 10 loss = 0.6413829326629639
epoch = 1 index = 11 loss = 0.6246744394302368
epoch = 1 index = 12 loss = 0.6163585782051086
epoch = 1 index = 13 loss = 0.599936306476593
epoch = 1 index = 14 loss = 0.6216733455657959
epoch = 1 index = 15 loss = 0.6504020094871521
epoch = 1 index = 16 loss = 0.6451072096824646
epoch = 1 index = 17 loss = 0.6215073466300964
epoch = 1 index = 18 loss = 0.6641662120819092
epoch = 1 index = 19 loss = 0.6364893317222595
epoch = 1 index = 20 loss = 0.6020426154136658
epoch = 1 index = 21 loss = 0.617006778717041
epoch = 1 index = 22 loss = 0.653681218624115
epoch = 1 index = 23 loss = 0.5835389494895935
epoch = 1 index = 24 loss = 0.6029499173164368
0.6296080400546392
epoch = 2 index = 1 loss = 0.6385740637779236
epoch = 2 index = 2 loss = 0.6440627574920654
epoch = 2 index = 3 loss = 0.6580216288566589
........
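The script accumulates epoch_list and loss_list but never plots them. A minimal sketch to visualize the per-epoch average loss, appended after the training loop above (assuming matplotlib is installed):

import matplotlib.pyplot as plt

plt.plot(epoch_list, loss_list)  # average BCE loss per epoch
plt.xlabel('epoch')
plt.ylabel('average loss')
plt.show()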

Finally, swap in the deeper model below and train for two hundred epochs:
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.relu(self.linear2(x))
        # Note: the last layer must not use ReLU. BCELoss needs a probability in (0, 1),
        # and ReLU outputs (exactly 0, or unbounded above 1) break its log terms and gradients.
        x = self.sigmoid(self.linear3(x))
        return x
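Related to that last comment: a common alternative (not what this post does) is to output raw logits and use torch.nn.BCEWithLogitsLoss, which fuses the sigmoid into the loss for better numerical stability. A sketch, with the class name LogitsModel being hypothetical:

class LogitsModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.relu = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.relu(self.linear2(x))
        return self.linear3(x)  # raw logits: no sigmoid on the last layer

criterion = torch.nn.BCEWithLogitsLoss()  # applies sigmoid internally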
