Implementing mini-batches with PyTorch DataLoader (incomplete)
2022-07-03 05:48:00 【code bean】
This follows up on my previous post, 《Building the simplest neural network with PyTorch》. Last time I did not use mini-batches and fed all of the data into training at once.
This time I use DataLoader to implement mini-batch training.
```python
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])   # all columns except the last: features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # last column as an (N, 1) label tensor

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset,
                          batch_size=749,
                          shuffle=False,
                          num_workers=1)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    # forward() is invoked from nn.Module.__call__()
    def forward(self, x):
        # x = self.activate(self.linear1(x))
        # x = self.activate(self.linear2(x))
        # x = self.activate(self.linear3(x))
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


# model is callable: nn.Module implements __call__(), which runs forward()
model = Model()

# Loss function
# criterion = torch.nn.MSELoss(size_average=False)  # (or True); deprecated, use reduction instead
# criterion = torch.nn.MSELoss(reduction='sum')  # 'sum' adds up the losses, 'mean' averages them
criterion = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy loss

# Optimizer: the gradient-descent algorithm that performs the updates; it is bound to model.parameters()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Adam was very accurate here; the others did much worse
# optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)

# Previous version: full-batch training without DataLoader
# for epoch in range(5000):
#     y_pred = model(x_data)  # feed the whole training set at once
#     loss = criterion(y_pred, y_data)
#     print(epoch, loss.item())
#     optimizer.zero_grad()  # clears the gradients of every w and b (why can't this go after loss.backward()?)
#     loss.backward()
#     optimizer.step()  # updates every w and b

if __name__ == '__main__':
    for epoch in range(1000):
        for i, data in enumerate(train_loader, 0):
            # 1. Prepare data
            inputs, labels = data
            # 2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4. Update
            optimizer.step()

    # Test
    xy = np.loadtxt('diabetes_test.csv', delimiter=',', dtype=np.float32)
    x_test = torch.from_numpy(xy[:, :-1])
    y_data = torch.from_numpy(xy[:, -1])
    with torch.no_grad():  # no gradients are needed for prediction
        y_test = model(x_test)  # predict
    print('y_pred = ', y_test)
    # Compare the predicted labels (threshold 0.5) with the true labels
    for index, p in enumerate(y_test.numpy()):
        if p[0] > 0.5:
            print(1, int(y_data[index].item()))
        else:
            print(0, int(y_data[index].item()))
```
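To see what `DataLoader` does with a dataset shaped like `DiabetesDataset`, here is a minimal sketch that swaps `diabetes.csv` for a hypothetical random array with the same layout (8 feature columns plus a label column); `ArrayDataset` is an illustrative name, not part of the original code. The loader calls `__getitem__` once per index, and the default collate function stacks the returned `(x, y)` pairs into batch tensors, which is why `inputs, labels = data` works in the training loop.

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


class ArrayDataset(Dataset):
    """Same structure as DiabetesDataset, but built from an in-memory array
    instead of loading diabetes.csv (hypothetical stand-in for illustration)."""

    def __init__(self, xy):
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])   # features, shape (N, 8)
        self.y_data = torch.from_numpy(xy[:, [-1]])  # labels kept 2-D, shape (N, 1)

    def __getitem__(self, index):
        # One sample: x of shape (8,), y of shape (1,)
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


# Hypothetical data: 100 rows, 8 features + 1 binary label, float32 like np.loadtxt(..., dtype=np.float32)
rng = np.random.default_rng(0)
xy = np.hstack([rng.standard_normal((100, 8)),
                rng.integers(0, 2, size=(100, 1))]).astype(np.float32)

dataset = ArrayDataset(xy)
loader = DataLoader(dataset, batch_size=32, shuffle=False)

batch_shapes = []
for inputs, labels in loader:
    # The default collate_fn stacks per-sample tensors along a new first dimension
    batch_shapes.append((tuple(inputs.shape), tuple(labels.shape)))

print(batch_shapes)
```

With 100 rows and `batch_size=32`, the loader yields three full batches of 32 and a final batch holding the 4 leftover rows.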
Key code excerpt:

```python
train_loader = DataLoader(dataset=dataset,
                          batch_size=749,
                          shuffle=False,
                          num_workers=1)
```

batch_size is the number of samples in each training batch: if the dataset had 7490 samples in total, train_loader would produce 10 batches. shuffle means shuffling: the order of the samples is randomized (anew at each epoch) before they are split into those 10 batches.
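The batch arithmetic in this paragraph is `ceil(N / batch_size)` batches per epoch, where the last batch is smaller when `N` is not a multiple of `batch_size` (or is discarded with `drop_last=True`), and `shuffle=True` draws a new random order each time the loader is iterated. A sketch using a hypothetical `TensorDataset` of 7490 zero rows in place of real data:

```python
import math

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data: 7490 samples with 8 features and 1 label each
n_samples = 7490
dataset = TensorDataset(torch.zeros(n_samples, 8), torch.zeros(n_samples, 1))

# 7490 / 749 = 10 batches per epoch
loader = DataLoader(dataset, batch_size=749, shuffle=True)
print(len(loader), math.ceil(n_samples / 749))

# If batch_size does not divide N, the last batch holds the remainder ...
sizes = [x.shape[0] for x, _ in DataLoader(dataset, batch_size=1000)]
print(sizes)
# ... unless drop_last=True discards it
dropped = DataLoader(dataset, batch_size=1000, drop_last=True)
print(len(dropped))

# shuffle=True: each pass over the loader (each epoch) uses a fresh random permutation
shuffled = DataLoader(TensorDataset(torch.arange(10)), batch_size=5, shuffle=True)
orders = []
for epoch in range(2):
    order = torch.cat([b for (b,) in shuffled]).tolist()
    orders.append(order)
    print(epoch, order)  # always a permutation of 0..9
```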

So why did I set batch_size to 749 and turn shuffling off? Because I wanted to compare against the previous training process.
Since there are only 749 samples, there is only one batch, and that batch contains all of the data in its original order. This is the same as the previous training process.
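The reasoning above can be checked directly: with `batch_size` equal to the dataset size and `shuffle=False`, the loader yields exactly one batch whose tensors equal the full data, so each epoch performs one optimizer step on the same tensors as the earlier full-batch loop. A sketch with hypothetical random tensors standing in for the 749 rows of `diabetes.csv`:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for the 749 diabetes rows: 8 features + 1 binary label
x_data = torch.randn(749, 8)
y_data = torch.randint(0, 2, (749, 1)).float()
dataset = TensorDataset(x_data, y_data)

# batch_size = len(dataset) and shuffle=False: one batch per epoch, in file order
loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)

batches = list(loader)
inputs, labels = batches[0]
print(len(batches))                                          # number of batches per epoch
print(torch.equal(inputs, x_data), torch.equal(labels, y_data))  # same tensors as full-batch
```

For datasets like these, which define only `__getitem__`, the extra work per epoch is that the loader still fetches the rows one at a time and stacks them back together, which is cheap at this size.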
I did this because I found that training became unusually slow after switching to DataLoader. With the configuration above I expected the speed to be the same as before, but it was actually more than 10 times slower!
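A likely contributor to the slowdown (an assumption based on how `DataLoader` works, not something measured in this post): `persistent_workers` defaults to `False`, so with `num_workers=1` a new worker process is started every time the loader is iterated, i.e. once per epoch, 1000 times in total. On Windows that worker is started with `spawn`, which re-imports the script and re-runs its module-level code, including the `np.loadtxt('diabetes.csv', ...)` call and the model construction. For a dataset this small, producing batches in the main process with `num_workers=0` avoids all of that. A sketch that times such a loop on hypothetical random data of the same shape:

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in data with the same shape as the diabetes set
dataset = TensorDataset(torch.randn(749, 8), torch.randint(0, 2, (749, 1)).float())

# num_workers=0: batches are built in the main process, so no worker
# process has to be started at the beginning of each epoch
loader = DataLoader(dataset, batch_size=749, shuffle=False, num_workers=0)

start = time.perf_counter()
n_batches = 0
for epoch in range(100):
    for inputs, labels in loader:
        n_batches += 1
elapsed = time.perf_counter() - start
print(f"{n_batches} batches in {elapsed:.2f} s")
```

If worker processes are genuinely needed (for example, for expensive per-sample preprocessing), `persistent_workers=True` (which requires `num_workers > 0`) keeps them alive across epochs instead of restarting them; moving the data loading and model setup under `if __name__ == '__main__':` also stops each spawned worker from re-running them.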
Worse, if I make batch_size smaller, the loss does not converge at all, and after too many loop iterations the following error is raised:

```
    return _winapi.DuplicateHandle(
PermissionError: [WinError 5] Access denied
```

So how is DataLoader supposed to be used? Could anyone offer some advice?
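On the convergence question: with a smaller `batch_size`, each printed `loss.item()` is the loss of a single small batch, which is much noisier than the full-batch loss printed before, so the per-batch printout can look as if it never settles even while training progresses. A hedged sketch that instead records the average training loss over each epoch (weighted by batch size); the data is hypothetical and linearly separable, and the network only mirrors the post's layer sizes:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical data: the label is 1 when the 8 features sum to a positive number
x = torch.randn(512, 8)
y = (x.sum(dim=1, keepdim=True) > 0).float()
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True)

# Same layer sizes and activations as the post's Model, written as nn.Sequential
model = torch.nn.Sequential(
    torch.nn.Linear(8, 6), torch.nn.Sigmoid(),
    torch.nn.Linear(6, 4), torch.nn.Sigmoid(),
    torch.nn.Linear(4, 1), torch.nn.Sigmoid(),
)
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

epoch_losses = []
for epoch in range(50):
    total, count = 0.0, 0
    for inputs, labels in loader:
        loss = criterion(model(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # loss is a per-sample mean, so weight it by the batch size before summing
        total += loss.item() * inputs.shape[0]
        count += inputs.shape[0]
    epoch_losses.append(total / count)  # average training loss seen during this epoch
    if epoch % 10 == 0:
        print(epoch, epoch_losses[-1])
```

Comparing this per-epoch average, rather than individual batch losses, against the old full-batch curve gives a fairer picture of whether the smaller batches are actually failing to converge.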