PyTorch Study Notes 08: Loading Datasets
2022-07-31 05:16:00 【qq_50749521】
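In the previous note on the diabetes dataset, every update used the whole dataset as a single input. This time the data is fed in mini-batches.
Three concepts:
epoch: one pass over all training samples
Batch-Size: the number of samples in each batch
iteration: one pass over a single batch (one parameter update)
For example, split a dataset of 200 samples into 40 batches and each batch holds 5 samples.
That gives 40 batches with batch_size = 5.
During training, processing one batch of 5 samples is 1 iteration.
Processing all 40 batches means all 200 samples have been used once, which is 1 epoch.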
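The arithmetic in the example can be checked with a few lines of Python. This is only a sketch: `num_epochs` is a made-up value added for illustration, and `math.ceil` is used so that a final partial batch would also be counted when the sample count is not a multiple of the batch size.

```python
import math

num_samples = 200   # size of the example dataset from the text
batch_size = 5      # samples per mini-batch
num_epochs = 3      # hypothetical number of epochs, for illustration

# One iteration processes one batch, so an epoch needs this many iterations.
iterations_per_epoch = math.ceil(num_samples / batch_size)
total_iterations = iterations_per_epoch * num_epochs

print(iterations_per_epoch)  # 40
print(total_iterations)      # 120
```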
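DataLoader: a way of loading a dataset
What does it do for us? For mini-batch training we usually shuffle the dataset so that the batches differ from epoch to epoch.
Given a dataset that supports indexing and has a known length, DataLoader produces the mini-batches automatically.
Dataset -> Shuffle -> Loader (split into mini-batches)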
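A minimal sketch of that pipeline, assuming a toy dataset of 10 samples wrapped in `TensorDataset` (the tensors and the batch size of 4 are made up for illustration). It shows that shuffling changes the order but every sample still appears exactly once per epoch, and that the last batch holds the remainder.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy data (hypothetical): 10 samples with 2 features each, plus one label per sample.
features = torch.arange(20, dtype=torch.float32).reshape(10, 2)
labels = torch.arange(10, dtype=torch.float32).reshape(10, 1)

# TensorDataset supports indexing and len(), which is all DataLoader needs.
toy_dataset = TensorDataset(features, labels)

torch.manual_seed(0)  # only to make the shuffle order repeatable
loader = DataLoader(toy_dataset, batch_size=4, shuffle=True)

batch_sizes = []
seen_labels = []
for x, y in loader:
    batch_sizes.append(x.shape[0])
    seen_labels.extend(y.flatten().tolist())

print(batch_sizes)          # [4, 4, 2]: the last batch holds the remainder
print(sorted(seen_labels))  # each of the 10 labels appears exactly once
```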
How do you define your own Dataset?
Here is a skeleton:
    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader

    class DiabetesDataset(Dataset):
        def __init__(self):
            # load or locate the data here
            pass

        def __getitem__(self, index):
            # return the sample stored at position index
            pass

        def __len__(self):
            # return the number of samples (must be an int before DataLoader can use it)
            pass

    dataset = DiabetesDataset()
    train_loader = DataLoader(dataset=dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=2)
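PyTorch provides the Dataset class. It is designed as an abstract base class: you subclass it rather than use it directly.
- DiabetesDataset above is our own subclass of Dataset. __getitem__ and __len__ are magic methods: the first returns the sample at a given index, the second returns the number of samples in the dataset.
- After instantiating DiabetesDataset, DataLoader builds the mini-batches automatically. Here it is configured with batch_size, shuffle and num_workers.
batch_size = 32 sets the number of samples per batch, and shuffle = True reshuffles the dataset at every epoch. num_workers = 2 means two worker subprocesses (processes, not threads) load the data in parallel while the mini-batches are assembled. With more CPU cores you can set it higher.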
With that we have train_loader, the data in the form we want, and training can begin.
    for epoch in range(100):
        for index, data in enumerate(train_loader, 0):
            # index is the batch index within the epoch (0 up to the number of batches - 1);
            # data holds the (inputs, labels) tensors for that batch
            inputs, labels = data
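To try the loop without the CSV file, the sketch below uses a hypothetical `ToyDataset` with 100 random samples and 8 features in place of DiabetesDataset. The class name, sizes and random data are assumptions made for illustration; only the `__getitem__` / `__len__` pattern comes from the text.

```python
import torch
from torch.utils.data import Dataset, DataLoader

class ToyDataset(Dataset):
    """Hypothetical stand-in for DiabetesDataset: random features and labels."""

    def __init__(self, num_samples=100, num_features=8):
        generator = torch.Generator().manual_seed(0)
        self.x_data = torch.rand(num_samples, num_features, generator=generator)
        # Binary labels stored as a float column, shape (num_samples, 1)
        self.y_data = (torch.rand(num_samples, 1, generator=generator) > 0.5).float()

    def __getitem__(self, index):
        # One sample: a (features, label) tuple
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.x_data.shape[0]

toy = ToyDataset()
toy_loader = DataLoader(toy, batch_size=32, shuffle=True, num_workers=0)

shapes = []
for index, (inputs, labels) in enumerate(toy_loader):
    shapes.append((tuple(inputs.shape), tuple(labels.shape)))

print(len(toy))    # 100
print(shapes[0])   # ((32, 8), (32, 1))
print(shapes[-1])  # ((4, 8), (4, 1)): 100 = 3 * 32 + 4
```

DataLoader stacks the individual samples returned by `__getitem__` along a new first dimension, which is why each batch arrives as two tensors.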
The complete code for mini-batch training on the diabetes data:
    import torch
    import numpy as np
    from torch.utils.data import Dataset
    from torch.utils.data import DataLoader

    class DiabetesDataset(Dataset):
        def __init__(self, filepath):
            # The file is small, so read it into memory in one go
            xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
            self.len = xy.shape[0]                       # number of samples
            self.x_data = torch.from_numpy(xy[:, :-1])   # all columns but the last: features
            self.y_data = torch.from_numpy(xy[:, [-1]])  # last column, kept 2-D: labels

        def __getitem__(self, index):
            return self.x_data[index], self.y_data[index]  # returns a tuple

        def __len__(self):
            return self.len

    # Raw string so the backslashes in the Windows path are not read as escapes
    dataset = DiabetesDataset(r'F:\ASR-source\Dataset\diabetes.csv.gz')
    train_loader = DataLoader(dataset=dataset,
                              batch_size=32,
                              shuffle=True,
                              num_workers=0)
    batch_size = 32
    # Batches per epoch: ceil rather than round, because the last partial batch still counts
    batch = np.ceil(len(dataset) / batch_size)
    batch
    24.0
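The same count is available as `len(train_loader)`. The sketch below assumes a placeholder dataset of 759 all-zero samples; 759 is an assumed sample count, chosen only because it yields 24 batches of 32, and is not stated in the text. It also shows the effect of `drop_last=True`, which discards the final partial batch.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

num_samples = 759   # assumed sample count, for illustration only
batch_size = 32

# Placeholder data with 8 feature columns and 1 label column
placeholder = TensorDataset(torch.zeros(num_samples, 8), torch.zeros(num_samples, 1))

keep_last = DataLoader(placeholder, batch_size=batch_size)
drop_last = DataLoader(placeholder, batch_size=batch_size, drop_last=True)

print(len(keep_last))                       # 24, the partial batch is kept
print(len(drop_last))                       # 23, the partial batch is dropped
print(math.ceil(num_samples / batch_size))  # 24
```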
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear1 = torch.nn.Linear(8, 1)
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            x = self.sigmoid(self.linear1(x))
            return x

    mymodel = Model()
    criterion = torch.nn.BCELoss(reduction='mean')
    optimizer = torch.optim.SGD(mymodel.parameters(), lr=0.01)

    epoch_list = []
    loss_list = []
    sum_loss = 0

    if __name__ == '__main__':
        for epoch in range(100):
            # train_loader yields the shuffled mini-batches together with their labels
            for index, data in enumerate(train_loader, 0):
                inputs, labels = data  # both are tensors
                y_pred = mymodel(inputs)
                loss = criterion(y_pred, labels)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                sum_loss += loss.item()
                print('epoch = ', epoch + 1, 'index = ', index + 1, 'loss = ', loss.item())
            # After each epoch, record and print the mean loss over its batches
            epoch_list.append(epoch)
            loss_list.append(sum_loss / batch)
            print(sum_loss / batch)
            sum_loss = 0
epoch = 1 index = 1 loss = 0.6523504257202148
epoch = 1 index = 2 loss = 0.6662447452545166
epoch = 1 index = 3 loss = 0.6510850191116333
epoch = 1 index = 4 loss = 0.622829794883728
epoch = 1 index = 5 loss = 0.6272122263908386
epoch = 1 index = 6 loss = 0.5990191102027893
epoch = 1 index = 7 loss = 0.6213780045509338
epoch = 1 index = 8 loss = 0.6761874556541443
epoch = 1 index = 9 loss = 0.6133689880371094
epoch = 1 index = 10 loss = 0.6413829326629639
epoch = 1 index = 11 loss = 0.6246744394302368
epoch = 1 index = 12 loss = 0.6163585782051086
epoch = 1 index = 13 loss = 0.599936306476593
epoch = 1 index = 14 loss = 0.6216733455657959
epoch = 1 index = 15 loss = 0.6504020094871521
epoch = 1 index = 16 loss = 0.6451072096824646
epoch = 1 index = 17 loss = 0.6215073466300964
epoch = 1 index = 18 loss = 0.6641662120819092
epoch = 1 index = 19 loss = 0.6364893317222595
epoch = 1 index = 20 loss = 0.6020426154136658
epoch = 1 index = 21 loss = 0.617006778717041
epoch = 1 index = 22 loss = 0.653681218624115
epoch = 1 index = 23 loss = 0.5835389494895935
epoch = 1 index = 24 loss = 0.6029499173164368
0.6296080400546392
epoch = 2 index = 1 loss = 0.6385740637779236
epoch = 2 index = 2 loss = 0.6440627574920654
epoch = 2 index = 3 loss = 0.6580216288566589
........
Next, the model is replaced with the one below and trained for 200 epochs:
    class Model(torch.nn.Module):
        def __init__(self):
            super(Model, self).__init__()
            self.linear1 = torch.nn.Linear(8, 6)
            self.linear2 = torch.nn.Linear(6, 4)
            self.linear3 = torch.nn.Linear(4, 1)
            self.relu = torch.nn.ReLU()
            self.sigmoid = torch.nn.Sigmoid()

        def forward(self, x):
            x = self.sigmoid(self.linear1(x))
            x = self.relu(self.linear2(x))
            # Do not use ReLU on the last layer: BCELoss expects outputs in [0, 1],
            # which ReLU does not guarantee, and a ReLU output of 0 has zero gradient
            x = self.sigmoid(self.linear3(x))
            return x
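The note about the last layer can be checked with a quick sketch. The `torch.nn.Sequential` model below is an assumed, equivalent way of writing the same layer stack, and the random input batch is made up for illustration. The sigmoid output stays strictly between 0 and 1, whereas ReLU can return exactly 0 or values above 1.

```python
import torch

# Same layer order as the model above, written with Sequential for brevity
deep_model = torch.nn.Sequential(
    torch.nn.Linear(8, 6), torch.nn.Sigmoid(),
    torch.nn.Linear(6, 4), torch.nn.ReLU(),
    torch.nn.Linear(4, 1), torch.nn.Sigmoid(),
)

batch = torch.rand(32, 8)  # hypothetical mini-batch: 32 samples, 8 features
with torch.no_grad():
    probs = deep_model(batch)

# Every output is a valid probability, as BCELoss requires
in_range = bool(((probs > 0) & (probs < 1)).all())
print(tuple(probs.shape))  # (32, 1)
print(in_range)            # True

# ReLU, by contrast, clips negatives to 0 and leaves large values unbounded
relu_out = torch.relu(torch.tensor([-2.0, 0.5, 3.0]))
print(relu_out.tolist())   # [0.0, 0.5, 3.0]
```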