Building the simplest neural network in PyTorch
2022-07-03 05:49:00 【code bean】
The dataset contains medical indicators for a set of patients, plus a label saying whether each patient has diabetes.
diabetes.csv: training set
diabetes_test.csv: test set
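In the code, the features are read with `xy[:, :-1]` and the labels with `xy[:, [-1]]`. The following is a minimal sketch of why the list index matters, using a small made-up array in place of diabetes.csv (the values are hypothetical). Indexing with `[-1]` keeps a column of shape (N, 1), matching the (N, 1) output of the model, which `BCELoss` expects. Indexing with `-1` gives a flat vector of shape (N,).

```python
import numpy as np
import torch

# Hypothetical stand-in for diabetes.csv: 3 samples, 8 feature columns + 1 label column
xy = np.arange(27, dtype=np.float32).reshape(3, 9)

x_data = torch.from_numpy(xy[:, :-1])     # every column except the last: the features
y_column = torch.from_numpy(xy[:, [-1]])  # list index keeps 2-D: shape (N, 1)
y_flat = torch.from_numpy(xy[:, -1])      # scalar index drops a dimension: shape (N,)

print(x_data.shape)    # torch.Size([3, 8])
print(y_column.shape)  # torch.Size([3, 1])
print(y_flat.shape)    # torch.Size([3])
```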
```python
import torch
import numpy as np

# Load the training set: the first 8 columns are features, the last column is the label
xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])
# Note: the labels must be a two-dimensional (N, 1) matrix, hence [-1] rather than -1
y_data = torch.from_numpy(xy[:, [-1]])


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Reduce the feature dimension layer by layer: 8 -> 6 -> 4 -> 2 -> 1
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    # forward() is called by __call__(), so use model(x) rather than model.forward(x)
    def forward(self, x):
        # Alternative: ReLU on the hidden layers
        # x = self.activate(self.linear1(x))
        # x = self.activate(self.linear2(x))
        # x = self.activate(self.linear3(x))
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        # The final sigmoid maps the output to a probability in (0, 1)
        x = self.sigmoid(self.linear4(x))
        return x


# model is callable because nn.Module implements __call__()
model = Model()

# Loss function
# criterion = torch.nn.MSELoss(reduction='sum')  # 'sum': add the losses up, 'mean': average them
criterion = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy loss

# Optimizer (the gradient-descent algorithm); it is tied to the model's parameters
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # worked well in my runs; the others did much worse
# optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)

for epoch in range(5000):
    y_pred = model(x_data)  # full batch: feed the whole training set at once
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # backward() accumulates into the gradients of every w and b, so clear them first.
    # Calling zero_grad() between backward() and step() would erase the gradients step() needs.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # update every w and b using its gradient

# Test
xy = np.loadtxt('diabetes_test.csv', delimiter=',', dtype=np.float32)
x_test = torch.from_numpy(xy[:, :-1])
y_test = torch.from_numpy(xy[:, -1])  # 1-D labels, only used for comparison

with torch.no_grad():  # no gradients are needed for prediction
    y_prob = model(x_test)
print('y_pred = ', y_prob)

# Compare the predicted labels (threshold 0.5) with the true labels
for index, p in enumerate(y_prob.numpy()):
    predicted = 1 if p[0] > 0.5 else 0
    print(predicted, int(y_test[index].item()))
```
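The comment on `optimizer.zero_grad()` asks why it cannot go after `loss.backward()`. Here is a minimal sketch using a single made-up parameter `w` (not part of the original model). It shows that `backward()` adds into `.grad` rather than overwriting it, so gradients from successive passes pile up unless they are cleared. Clearing them between `backward()` and `step()` would leave `step()` with nothing to apply. In recent PyTorch versions `optimizer.zero_grad()` sets `.grad` to `None` by default instead of filling it with zeros; either way the old gradients are discarded.

```python
import torch

# One hypothetical parameter; loss = 3 * w, so d(loss)/dw = 3
w = torch.tensor(1.0, requires_grad=True)

loss = 3 * w
loss.backward()
print(w.grad)  # tensor(3.)

# A second backward pass without clearing: the new gradient is added to the old one
loss = 3 * w
loss.backward()
print(w.grad)  # tensor(6.)

# Clearing before the next backward pass gives the correct gradient again
w.grad.zero_()
loss = 3 * w
loss.backward()
print(w.grad)  # tensor(3.)
```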
As the code shows, each input sample has 8 features, and the layers reduce the dimension step by step until the output is one-dimensional. The input is multidimensional, but the output is a single probability value, so this is still a binary classification problem.
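Because the network outputs one probability per sample, the code trains it with `BCELoss` (binary cross-entropy). For predictions p and labels y in {0, 1}, the mean loss is -mean(y * log(p) + (1 - y) * log(1 - p)). The sketch below computes this by hand on made-up values and checks the result against `torch.nn.BCELoss`. It then turns the probabilities into class labels using the same 0.5 threshold as the test loop.

```python
import torch

# Hypothetical predicted probabilities (shape (N, 1), like the model output) and labels
p = torch.tensor([[0.9], [0.2], [0.6], [0.4]])
y = torch.tensor([[1.0], [0.0], [0.0], [1.0]])

# Binary cross-entropy written out by hand, averaged over the samples
manual = -(y * torch.log(p) + (1 - y) * torch.log(1 - p)).mean()

# The built-in loss with reduction='mean' computes the same quantity
builtin = torch.nn.BCELoss(reduction='mean')(p, y)
print(manual.item(), builtin.item())

# Threshold the probabilities at 0.5 to get class labels, then measure accuracy
labels = (p > 0.5).float()
accuracy = (labels == y).float().mean()
print(labels.squeeze(1).tolist(), accuracy.item())
```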
Each layer is specified by its number of input and output features. For example, Linear(8, 6) means each sample enters with 8 features and leaves with 6. The dimension that matters is the column dimension (the number of features). The row dimension is the number of samples, and a PyTorch layer does not care about it. It handles one sample or a whole pile of samples in the same way; only the number of columns has to match.
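To check that only the column count matters, here is a minimal sketch that feeds random inputs (hypothetical data) through a single `torch.nn.Linear(8, 6)`. The inputs have different numbers of rows. The sketch also recomputes the output by hand as x @ W^T + b, the affine map that `Linear` applies to each row.

```python
import torch

torch.manual_seed(0)
# Linear(8, 6): 8 input features per sample, 6 output features per sample
layer = torch.nn.Linear(8, 6)
print(layer.weight.shape, layer.bias.shape)  # weight is (out, in) = (6, 8), bias is (6,)

# The number of rows (samples) can be anything; only the last dimension must be 8
one_sample = torch.randn(1, 8)
many_samples = torch.randn(100, 8)
print(layer(one_sample).shape)    # torch.Size([1, 6])
print(layer(many_samples).shape)  # torch.Size([100, 6])

# Each output row is x @ W^T + b, computed independently for every sample
manual = many_samples @ layer.weight.T + layer.bias
print(torch.allclose(manual, layer(many_samples), atol=1e-5))
```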
The final code builds the model 8 -> 6 -> 4 -> 2 -> 1. I added one extra layer in my code; you can also try more.
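Note that although the layers are built from torch.nn.Linear, the model as a whole is nonlinear, because an activation function (sigmoid here) is applied after every Linear layer. Without activations, a stack of linear layers collapses into a single linear map, and the result would not really be a neural network.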
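This sketch shows the collapse directly, using two hypothetical layers (8 -> 6 -> 1) with no activation between them. Since l2(l1(x)) = (x W1^T + b1) W2^T + b2 = x (W2 W1)^T + (W2 b1 + b2), one `Linear(8, 1)` with weight W2 W1 and bias W2 b1 + b2 gives exactly the same outputs.

```python
import torch

torch.manual_seed(0)
# Two stacked Linear layers with no activation in between (hypothetical 8 -> 6 -> 1 net)
l1 = torch.nn.Linear(8, 6)
l2 = torch.nn.Linear(6, 1)

# Build one equivalent Linear(8, 1):
# l2(l1(x)) = x (W2 W1)^T + (W2 b1 + b2)
merged = torch.nn.Linear(8, 1)
with torch.no_grad():
    merged.weight.copy_(l2.weight @ l1.weight)            # (1, 6) @ (6, 8) -> (1, 8)
    merged.bias.copy_(l2.weight @ l1.bias + l2.bias)      # (1, 6) @ (6,) + (1,) -> (1,)

x = torch.randn(5, 8)
with torch.no_grad():
    stacked_out = l2(l1(x))
    merged_out = merged(x)
print(torch.allclose(stacked_out, merged_out, atol=1e-5))
```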
You can also try other activation functions, for example ReLU, as in the commented-out lines of forward().
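Swapping activations is easier when the choice is a constructor argument. The sketch below uses a hypothetical `ConfigurableModel` (not from the original code) with the same 8 -> 6 -> 4 -> 2 -> 1 layout. It takes the hidden-layer activation class as a parameter and keeps a sigmoid on the output, so the result is still a probability for `BCELoss`. The input is random, and the activations tried are simply examples of modules that exist in `torch.nn`.

```python
from typing import Callable

import torch


class ConfigurableModel(torch.nn.Module):
    """Hypothetical variant of Model: same 8 -> 6 -> 4 -> 2 -> 1 layout,
    with the hidden-layer activation passed in as a parameter."""

    def __init__(self, make_activation: Callable[[], torch.nn.Module]):
        super().__init__()
        self.hidden = torch.nn.Sequential(
            torch.nn.Linear(8, 6), make_activation(),
            torch.nn.Linear(6, 4), make_activation(),
            torch.nn.Linear(4, 2), make_activation(),
        )
        self.out = torch.nn.Linear(2, 1)

    def forward(self, x):
        # Keep a sigmoid on the output so it is still a probability for BCELoss
        return torch.sigmoid(self.out(self.hidden(x)))


x = torch.randn(4, 8)  # 4 random samples with 8 features each
for make_activation in (torch.nn.ReLU, torch.nn.Tanh, torch.nn.ELU, torch.nn.Sigmoid):
    model = ConfigurableModel(make_activation)
    with torch.no_grad():
        y = model(x)
    print(make_activation.__name__, y.squeeze(1).tolist())
```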
Feel free to share your settings in the comments and show how well your predictions turn out.
The data is not split into mini-batches yet; I will add that later.
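One common way to add mini-batches is with PyTorch's `TensorDataset` and `DataLoader`. The sketch below makes some assumptions. It uses random tensors in place of the CSV data and a shortened model. The values batch_size=32 and 10 epochs are arbitrary choices, not settings from this post. Each epoch now performs several smaller updates instead of one full-batch update.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)
# Hypothetical stand-ins for x_data (N, 8) and y_data (N, 1) loaded from diabetes.csv
x_data = torch.randn(100, 8)
y_data = (torch.rand(100, 1) > 0.5).float()

dataset = TensorDataset(x_data, y_data)
loader = DataLoader(dataset, batch_size=32, shuffle=True)  # reshuffled every epoch

model = torch.nn.Sequential(torch.nn.Linear(8, 6), torch.nn.Sigmoid(),
                            torch.nn.Linear(6, 1), torch.nn.Sigmoid())
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(10):
    for x_batch, y_batch in loader:  # one parameter update per mini-batch
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(epoch, loss.item())  # loss of the last mini-batch in this epoch
```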