PyTorch Deep Learning Practice, Lesson 8: Loading Data (Dataset and DataLoader)
2022-07-25 03:29:00, falldeep
Based on Liu Er's video course on Bilibili:
https://www.bilibili.com/video/BV1Y7411d7Ys?p=9&vd_source=79d752a233297190ff0b01ca81ccd878
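Code (in-class exercise)
The task is the same diabetes binary classification as in the previous lesson. It is built in four steps: prepare the data, design the model, construct the loss and optimizer, and run the training cycle.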
```python
import torch
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# ------------------------------ step 1: prepare data ------------------------------
class DiabetesDataset(Dataset):  # custom dataset class, inherits from Dataset
    def __init__(self, filepath):
        # arguments: file path, column delimiter, element dtype
        xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
        self.len = xy.shape[0]  # number of samples (rows)
        self.x_data = torch.from_numpy(xy[:, :-1])   # every row, all columns except the last: the features
        self.y_data = torch.from_numpy(xy[:, [-1]])  # every row, last column only, kept 2-D with shape (N, 1)

    def __getitem__(self, index):  # return one sample (features, label) by index
        return self.x_data[index], self.y_data[index]

    def __len__(self):  # number of samples; DataLoader uses this to build batches
        return self.len


dataset = DiabetesDataset('diabetes.csv')
dataloader = DataLoader(dataset=dataset, batch_size=32, shuffle=True)
```
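Conceptually, a `DataLoader` with `shuffle=True` does three things. It reorders the indices randomly each epoch, calls `__getitem__` for each index, and stacks the results into batches of `batch_size`. The last batch is smaller unless `drop_last=True`. The pure-Python sketch below imitates that behaviour so the slicing and batching are visible. `ToyDataset`, `iterate_batches` and the random 100-row array are hypothetical stand-ins, not PyTorch code. The real `DataLoader` also converts samples to tensors, supports worker processes, and more.

```python
import random

import numpy as np


class ToyDataset:
    """Map-style dataset: the same __getitem__/__len__ protocol that Dataset subclasses implement."""

    def __init__(self, xy):
        self.x = xy[:, :-1]   # all columns except the last -> features, shape (N, 8)
        self.y = xy[:, [-1]]  # last column selected with a list -> labels stay 2-D, shape (N, 1)

    def __getitem__(self, index):
        return self.x[index], self.y[index]

    def __len__(self):
        return len(self.x)


def iterate_batches(dataset, batch_size, shuffle=True, seed=None):
    """Yield (features, labels) mini-batches, roughly like a single-process DataLoader (hypothetical helper)."""
    indices = list(range(len(dataset)))
    if shuffle:
        random.Random(seed).shuffle(indices)  # new order per call unless a fixed seed is passed
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]  # the final slice may be shorter than batch_size
        samples = [dataset[i] for i in batch_idx]
        xs = np.stack([s[0] for s in samples])  # stack samples along a new leading batch axis
        ys = np.stack([s[1] for s in samples])
        yield xs, ys


# Hypothetical data: 100 rows, 8 feature columns + 1 label column
rng = np.random.default_rng(0)
xy = rng.random((100, 9)).astype(np.float32)
toy_dataset = ToyDataset(xy)
batch_shapes = [(xb.shape, yb.shape) for xb, yb in iterate_batches(toy_dataset, batch_size=32, seed=0)]
print(batch_shapes)  # three batches of 32 rows, then one batch of 4 rows
```

With 100 samples and `batch_size=32`, one epoch yields 4 batches. This is the value that `len(dataloader)` reports for the real loader when `drop_last=False`.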
```python
# ------------------------------ step 2: design model ------------------------------
class Model(torch.nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.linear1 = torch.nn.Linear(8, 6)  # 8 input features -> 6
        self.linear2 = torch.nn.Linear(6, 4)  # 6 -> 4
        self.linear3 = torch.nn.Linear(4, 1)  # 4 -> 1 output
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))  # output in (0, 1), as BCELoss expects
        return x


model = Model()
```
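`torch.nn.Linear(in_features, out_features)` computes `x @ W.T + b`, where `W` has shape `(out_features, in_features)`. Stacking three such layers with sigmoids therefore maps a batch of shape `(batch, 8)` to `(batch, 1)` probabilities. The NumPy sketch below reproduces only that forward pass, using random (untrained, hypothetical) weights, to show the shape flow. It has no autograd and does no training.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def linear(x, weight, bias):
    # Same computation as torch.nn.Linear: x @ W^T + b, with W shaped (out_features, in_features)
    return x @ weight.T + bias


rng = np.random.default_rng(0)
layer_sizes = [(8, 6), (6, 4), (4, 1)]  # (in_features, out_features) per layer, as in the model above
params = [(rng.standard_normal((n_out, n_in)), rng.standard_normal(n_out)) for n_in, n_out in layer_sizes]

x = rng.standard_normal((5, 8))  # a hypothetical batch of 5 samples with 8 features each
shapes = [x.shape]
for weight, bias in params:
    x = sigmoid(linear(x, weight, bias))  # linear layer followed by sigmoid, like forward()
    shapes.append(x.shape)

print(shapes)  # [(5, 8), (5, 6), (5, 4), (5, 1)]
```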
```python
# ------------------------ step 3: construct loss and optimizer ------------------------
criteration = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy, averaged over the batch
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)  # update the model's parameters with learning rate 0.0001
```
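With `reduction='mean'`, `BCELoss` averages the per-sample binary cross-entropy, -[y*log(p) + (1-y)*log(1-p)], over the batch. Each `optimizer.step()` of plain SGD (no momentum or weight decay, as configured here) moves every parameter by `-lr * grad`. The NumPy sketch below computes both by hand. The function names and numbers are illustrative assumptions. PyTorch's `BCELoss` clamps its log outputs at -100, whereas this sketch clips the probabilities with a small epsilon.

```python
import numpy as np


def bce_mean(y_pred, y_true, eps=1e-12):
    # Binary cross-entropy averaged over the batch (like reduction='mean').
    # Clipping keeps log() finite; PyTorch instead clamps the log outputs at -100.
    p = np.clip(y_pred, eps, 1.0 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))


def sgd_step(param, grad, lr):
    # Plain SGD update (no momentum, no weight decay): param <- param - lr * grad
    return param - lr * grad


y_true = np.array([[1.0], [0.0]])
y_pred = np.array([[0.5], [0.5]])
print(bce_mean(y_pred, y_true))  # ≈ 0.6931, i.e. ln(2)

w = np.array([0.2, -0.4])
print(sgd_step(w, grad=np.array([1.0, -2.0]), lr=0.0001))  # approximately [0.1999, -0.3998]
```

A prediction of 0.5 costs ln(2) regardless of the label. With `lr=0.0001`, each step changes a parameter by only one ten-thousandth of its gradient.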
```python
# ------------------------------ step 4: training cycle ------------------------------
if __name__ == '__main__':  # needed on Windows when DataLoader uses num_workers > 0 (workers re-import this script); harmless otherwise
    loss_lst = []
    for epoch in range(1000):  # one epoch = one full pass over the whole dataset
        epoch_loss = 0.0
        for i, data in enumerate(dataloader, 0):  # each iteration yields one mini-batch
            inputs, labels = data  # x, y
            y_pred = model(inputs)
            loss = criteration(y_pred, labels)
            epoch_loss += loss.item()
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # mean of the per-batch losses: divide by the number of batches, not by batch_size
        loss_lst.append(epoch_loss / len(dataloader))

    # visualization
    plt.plot(range(len(loss_lst)), loss_lst)
    plt.xlabel("epoch")
    plt.ylabel("loss")
    plt.show()
```
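`reduction='mean'` already averages within each batch. The epoch-level figure should therefore be the mean over batches (`epoch_loss / len(dataloader)`), not the sum divided by `batch_size`. If the last batch is smaller, an exact per-sample mean weights each batch by its size. The sketch below compares the three computations on made-up batch losses; the helper names and numbers are hypothetical.

```python
def mean_of_batch_means(batch_losses):
    # What epoch_loss / len(dataloader) computes: every batch counts equally
    return sum(batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses, batch_sizes):
    # Exact mean over samples: weight each batch mean by the number of samples in it
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


# Hypothetical epoch: 100 samples with batch_size=32 -> batches of 32, 32, 32, 4
batch_losses = [0.70, 0.66, 0.62, 0.90]
batch_sizes = [32, 32, 32, 4]

print(mean_of_batch_means(batch_losses))           # ≈ 0.72
print(per_sample_mean(batch_losses, batch_sizes))  # ≈ 0.6696
print(sum(batch_losses) / 32)                      # ≈ 0.09: summed losses divided by batch_size, not a mean loss
```

The two means differ only slightly, because only the last batch is short. Dividing by `batch_size` instead shrinks the curve by a factor that has nothing to do with the loss itself.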
Running results (figure: training loss per epoch)
MNIST data import
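For MNIST, one common approach is to use `torchvision` rather than writing a `Dataset` by hand. `torchvision.datasets.MNIST` already implements `__getitem__` and `__len__`, so it can be passed straight to `DataLoader`. The helper below is a minimal sketch. It assumes `torch` and `torchvision` are installed, and both the function name `make_mnist_loaders` and the default `root` path are hypothetical. With `download=True`, the first call fetches the data files into `root`.

```python
def make_mnist_loaders(root="./dataset/mnist", batch_size=32):
    """Build MNIST train/test DataLoaders (hypothetical helper; requires torch and torchvision)."""
    # Imports live inside the function so this file can be loaded without torchvision installed
    from torch.utils.data import DataLoader
    from torchvision import datasets, transforms

    to_tensor = transforms.ToTensor()  # PIL image 28x28 with values 0-255 -> float tensor (1, 28, 28) in [0, 1]
    train_set = datasets.MNIST(root=root, train=True, transform=to_tensor, download=True)
    test_set = datasets.MNIST(root=root, train=False, transform=to_tensor, download=True)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False)  # evaluation order does not matter
    return train_loader, test_loader
```

Calling `train_loader, test_loader = make_mnist_loaders()` and then `images, labels = next(iter(train_loader))` should give `images` of shape `(32, 1, 28, 28)` and `labels` of shape `(32,)`. The training loop can then consume these batches exactly as it consumes the diabetes batches above.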