PyTorch DataLoader: implementing mini-batches (incomplete)
2022-07-03 05:48:00 【code bean】
Continuing from the previous post, 《Building the simplest neural network with PyTorch》: last time I did not use mini-batches and fed all the data into training at once.
This time I use DataLoader to implement mini-batches.
```python
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader


class DiabetesDataset(Dataset):
    def __init__(self, filepath):
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]
        self.x_data = torch.from_numpy(xy[:, :-1])
        self.y_data = torch.from_numpy(xy[:, [-1]])

    def __getitem__(self, index):
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        return self.len


dataset = DiabetesDataset('diabetes.csv')
train_loader = DataLoader(dataset=dataset,
                          batch_size=749,
                          shuffle=False,
                          num_workers=1)


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    # forward() is called from nn.Module.__call__()!
    def forward(self, x):
        # x = self.activate(self.linear1(x))
        # x = self.activate(self.linear2(x))
        # x = self.activate(self.linear3(x))
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        return x


# model is callable because nn.Module implements __call__()
model = Model()

# Specify the loss function
# criterion = torch.nn.MSELoss(size_average=False)  # or True; size_average is deprecated in favor of reduction
# criterion = torch.nn.MSELoss(reduction='sum')  # 'sum': sum up, 'mean': average
criterion = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy loss

# Specify the optimizer (i.e. the gradient-descent algorithm) and associate it with the model's parameters
# optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Adam was accurate here; the others performed poorly
# optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)

# Previous version: full-batch training without DataLoader
# for epoch in range(5000):
#     y_pred = model(x_data)  # feed the whole training set at once
#     loss = criterion(y_pred, y_data)
#     print(epoch, loss.item())
#     optimizer.zero_grad()  # clears the gradients of all w and b (why can't this go after loss.backward()?)
#     loss.backward()
#     optimizer.step()  # updates all w and b: this is the optimizer's job

if __name__ == '__main__':
    for epoch in range(1000):
        for i, data in enumerate(train_loader, 0):
            # 1. Prepare data
            inputs, labels = data
            # 2. Forward
            y_pred = model(inputs)
            loss = criterion(y_pred, labels)
            print(epoch, i, loss.item())
            # 3. Backward
            optimizer.zero_grad()
            loss.backward()
            # 4. Update
            optimizer.step()

    # Test
    xy = np.loadtxt('diabetes_test.csv', delimiter=',', dtype=np.float32)
    x_test = torch.from_numpy(xy[:, :-1])
    y_data = torch.from_numpy(xy[:, -1])
    with torch.no_grad():  # no gradients are needed for inference
        y_test = model(x_test)  # forecast
    print('y_pred = ', y_test)
    # Compare the predicted results with the real results (threshold 0.5)
    for index, p in enumerate(y_test.numpy()):
        pred = 1 if p[0] > 0.5 else 0
        print(pred, int(y_data[index].item()))
```
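On the question left in the code comments (why `optimizer.zero_grad()` cannot simply be called after `loss.backward()`): `backward()` adds the new gradients into each parameter's `.grad` rather than overwriting it, so `zero_grad()` has to run before `backward()` (or after `step()`), never between `backward()` and `step()`, where it would wipe out the gradients the optimizer is about to use. The toy parameter `w` and the helper `compute_loss` below are hypothetical, for illustration only.

```python
import torch

# A single toy parameter (hypothetical, for illustration only).
w = torch.tensor([2.0], requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

def compute_loss():
    # loss = w^2, so d(loss)/dw = 2w = 4 at w = 2
    return (w ** 2).sum()

compute_loss().backward()
first = w.grad.clone()     # gradient of one backward pass

compute_loss().backward()  # no zero_grad() in between
second = w.grad.clone()    # the new gradient was added to the old one

optimizer.zero_grad()      # reset the gradients of every parameter the optimizer manages
compute_loss().backward()
third = w.grad.clone()     # back to a single backward pass's gradient

print(first.item(), second.item(), third.item())  # 4.0 8.0 4.0
```

Since `step()` is never called here, `w` stays at 2.0 and only the accumulation behavior of `.grad` is visible.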
Key code excerpt:

```python
train_loader = DataLoader(dataset=dataset,
                          batch_size=749,
                          shuffle=False,
                          num_workers=1)
```

`batch_size` is the size of each training batch: if the dataset has 7490 samples in total, `train_loader` will produce 10 batches. `shuffle` randomizes the order of the samples (reshuffled at the start of every epoch) before they are split into those 10 batches.
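The batch arithmetic above can be checked with a synthetic `TensorDataset` of 7490 rows (the data is made up; only the shapes matter). With `batch_size=749` the loader yields `ceil(7490 / 749) = 10` batches, and `shuffle=True` only changes which rows land in each batch, not how many batches there are.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
# Synthetic stand-in for a real CSV: 7490 rows, 8 features, 1 label column.
x = torch.randn(7490, 8)
y = torch.randint(0, 2, (7490, 1)).float()
dataset = TensorDataset(x, y)

for shuffle in (False, True):
    loader = DataLoader(dataset, batch_size=749, shuffle=shuffle, num_workers=0)
    sizes = [inputs.shape[0] for inputs, labels in loader]
    # Same batch count either way; every batch holds 749 rows.
    print(f"shuffle={shuffle}: {len(loader)} batches, sizes={set(sizes)}")

# When N is not a multiple of batch_size, the last batch is smaller
# (pass drop_last=True to discard it instead).
tail_loader = DataLoader(TensorDataset(x[:7000], y[:7000]), batch_size=749)
tail_sizes = [inputs.shape[0] for inputs, _ in tail_loader]
print(len(tail_loader), math.ceil(7000 / 749), tail_sizes[-1])  # 10 10 259
```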

So why did I set `batch_size` to 749 and turn shuffling off? Because I wanted to compare against the previous training process.
Since there are only 749 samples, there is only one batch, and it contains all the data in the original order. This is the same as the previous (full-batch) training process.
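To confirm that `batch_size=len(dataset)` with `shuffle=False` really reproduces full-batch training, one can check that the loader yields exactly one batch identical to the full tensors. A sketch with synthetic data standing in for `diabetes.csv` (749 rows and 8 features assumed):

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
x_data = torch.randn(749, 8)                    # synthetic features
y_data = torch.randint(0, 2, (749, 1)).float()  # synthetic 0/1 labels
loader = DataLoader(TensorDataset(x_data, y_data),
                    batch_size=len(x_data), shuffle=False)

batches = list(loader)
inputs, labels = batches[0]
# One batch with the same rows in the same order as the full tensors.
print(len(batches), torch.equal(inputs, x_data), torch.equal(labels, y_data))  # 1 True True
```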
I did this because I found that after switching to DataLoader, training became unusually slow. With the configuration above I expected the speed to be the same as before, but it is actually more than 10 times slower!
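A likely explanation for the slowdown (not measured here): with `num_workers=1`, every new pass over `train_loader` starts a fresh worker process, because `persistent_workers` defaults to `False`. On Windows, worker processes are started with the spawn method, which re-imports the training script, so 1000 epochs means 1000 process launches for a dataset that already sits in memory, plus the cost of sending each batch back to the main process. Repeatedly launching workers may also be related to the `DuplicateHandle` error reported below, although that is not verified. For a small preloaded dataset, `num_workers=0` (load in the main process) is usually the fastest option; if workers are really needed, `persistent_workers=True` keeps them alive between epochs. The helper name `make_loader` is hypothetical.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader


def make_loader(dataset, batch_size, num_workers=0):
    """Hypothetical helper: build a DataLoader for a small in-memory dataset.

    num_workers=0 loads batches in the main process (no subprocesses).
    persistent_workers may only be True when num_workers > 0; it keeps the
    worker processes alive across epochs instead of re-spawning them.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )


if __name__ == "__main__":  # required on Windows whenever num_workers > 0
    torch.manual_seed(0)
    data = TensorDataset(torch.randn(749, 8), torch.randint(0, 2, (749, 1)).float())
    loader = make_loader(data, batch_size=32)  # main-process loading
    for epoch in range(3):
        n_rows = sum(inputs.shape[0] for inputs, _ in loader)
        print(epoch, len(loader), n_rows)  # 24 batches covering all 749 rows
```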
Worse still, if I make `batch_size` smaller, the loss simply fails to converge, and if the loop runs too many times the program crashes with:

```
    return _winapi.DuplicateHandle(
PermissionError: [WinError 5] Access is denied
```
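On the convergence problem, one possible reason (an assumption, not a diagnosis) is that with a small `batch_size` the value printed on every iteration is the loss of a single mini-batch, which naturally jumps around, so training can look non-convergent even when it is improving. A smoother signal is the mean loss over each epoch, and `shuffle=True` is the usual choice once batches are smaller than the dataset. The sketch uses synthetic, linearly separable data (not the diabetes set), the same 8-6-4-1 sigmoid network and Adam settings as the post, and `num_workers=0`.

```python
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader

torch.manual_seed(0)
# Synthetic binary-classification data standing in for diabetes.csv (assumption).
x = torch.randn(749, 8)
y = (x.sum(dim=1, keepdim=True) > 0).float()
loader = DataLoader(TensorDataset(x, y), batch_size=32, shuffle=True, num_workers=0)

# Same layer sizes and activations as the Model class in the post.
model = nn.Sequential(
    nn.Linear(8, 6), nn.Sigmoid(),
    nn.Linear(6, 4), nn.Sigmoid(),
    nn.Linear(4, 1), nn.Sigmoid(),
)
criterion = nn.BCELoss(reduction="mean")
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

epoch_losses = []
for epoch in range(30):
    total, count = 0.0, 0
    for inputs, labels in loader:
        loss = criterion(model(inputs), labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Weight by batch size so the short last batch is not over-counted.
        total += loss.item() * inputs.shape[0]
        count += inputs.shape[0]
    epoch_losses.append(total / count)
    if epoch % 10 == 0 or epoch == 29:
        print(epoch, round(epoch_losses[-1], 4))
```

Smaller batches also mean more optimizer steps per epoch, so the learning rate tuned for full-batch training may need to be revisited.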
How is `DataLoader` supposed to be used here? Can anyone offer some advice?