Building the simplest neural network with PyTorch
2022-07-03 05:49:00 【code bean】
The dataset contains clinical indicators of patients and whether or not they developed diabetes.

diabetes.csv: training set
diabetes_test.csv: test set
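For reference, here is a small sketch of why the training script slices the label column with `xy[:, [-1]]`: indexing with a list keeps the result two-dimensional, while a scalar index drops that axis. The three-row CSV text is made-up placeholder data, not rows from diabetes.csv.

```python
import io

import numpy as np
import torch

# Hypothetical stand-in for diabetes.csv: 3 rows, 8 feature columns + 1 label column.
csv_text = """\
0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0
0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1
0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0,0
"""

xy = np.loadtxt(io.StringIO(csv_text), delimiter=',', dtype=np.float32)

x_data = torch.from_numpy(xy[:, :-1])  # every column except the last: the features
y_2d = torch.from_numpy(xy[:, [-1]])   # list index keeps a column axis: shape (N, 1)
y_1d = torch.from_numpy(xy[:, -1])     # scalar index drops the axis: shape (N,)

print(x_data.shape)  # torch.Size([3, 8])
print(y_2d.shape)    # torch.Size([3, 1])
print(y_1d.shape)    # torch.Size([3])
```

The (N, 1) shape matches what the model outputs, which is what `BCELoss` expects during training.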
```python
import torch
import numpy as np

# Load the training set: 8 feature columns followed by 1 label column
xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])
# Note: y_data must be a two-dimensional matrix of shape (N, 1) to match the model output,
# so index with the list [-1] rather than the scalar -1
y_data = torch.from_numpy(xy[:, [-1]])


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    # forward() is called by nn.Module.__call__(), i.e. when you write model(x)
    def forward(self, x):
        # x = self.activate(self.linear1(x))
        # x = self.activate(self.linear2(x))
        # x = self.activate(self.linear3(x))
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        x = self.sigmoid(self.linear4(x))
        return x


# model is callable because nn.Module implements __call__()
model = Model()

# Specify the loss function
# criterion = torch.nn.MSELoss(reduction='sum')  # 'sum': add up; 'mean': average
criterion = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy loss

# Specify the optimizer (the gradient-descent algorithm that updates the weights);
# it is tied to the model through model.parameters()
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Adam was by far the most accurate in my tests
# optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)

for epoch in range(5000):
    y_pred = model(x_data)  # feed the whole training set at once (no mini-batches)
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())
    # Clear the gradients of all w and b. backward() accumulates into .grad and step() needs
    # this iteration's gradients, so zeroing cannot go between backward() and step()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()  # update all w and b using their gradients

# Test
xy_test = np.loadtxt('diabetes_test.csv', delimiter=',', dtype=np.float32)
x_test = torch.from_numpy(xy_test[:, :-1])
y_true = torch.from_numpy(xy_test[:, -1])
with torch.no_grad():  # no gradients are needed for inference
    y_test = model(x_test)  # forecast: probabilities in (0, 1)
print('y_pred = ', y_test)

# Compare the predicted labels (threshold 0.5) with the true labels
for index, p in enumerate(y_test.numpy()):
    pred = 1 if p[0] > 0.5 else 0
    print(pred, int(y_true[index].item()))
```
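On the question in the training loop about `optimizer.zero_grad()`: PyTorch adds each `backward()` result into `.grad` instead of overwriting it, and `step()` reads whatever is in `.grad`. Zeroing therefore has to happen before `backward()` (or after `step()`), never between them. A minimal sketch on a single hypothetical parameter `w`:

```python
import torch

# Hypothetical toy parameter: loss = (3 * w)^2, so d(loss)/dw = 18 * w = 18 at w = 1.
w = torch.tensor(1.0, requires_grad=True)


def compute_loss():
    return (3 * w) ** 2


compute_loss().backward()
print(w.grad)  # tensor(18.)

# A second backward() without clearing ADDS to .grad instead of replacing it.
compute_loss().backward()
print(w.grad)  # tensor(36.)

# Clearing first (what optimizer.zero_grad() does for every parameter) gives a fresh gradient.
w.grad.zero_()
compute_loss().backward()
print(w.grad)  # tensor(18.)
```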
As the code shows, each sample starts with 8 features, and the layers reduce the dimension step by step (8 → 6 → 4 → 2 → 1) until a single value comes out. Although the input is multi-dimensional, the output is one probability value, so this is still a binary classification problem.
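Because the single output is a probability, it is trained with binary cross-entropy, and the test loop turns it into a class label with a 0.5 threshold. A minimal sketch with made-up predictions and labels, checking `BCELoss` against the written-out formula and computing accuracy:

```python
import torch

# Hypothetical predicted probabilities (after the final sigmoid) and true labels, shape (N, 1).
y_pred = torch.tensor([[0.9], [0.2], [0.6], [0.4]])
y_true = torch.tensor([[1.0], [0.0], [0.0], [1.0]])

# Binary cross-entropy written out: -mean(y * log(p) + (1 - y) * log(1 - p))
manual_bce = -(y_true * torch.log(y_pred) + (1 - y_true) * torch.log(1 - y_pred)).mean()
builtin_bce = torch.nn.BCELoss(reduction='mean')(y_pred, y_true)
print(torch.allclose(manual_bce, builtin_bce))  # True

# Threshold at 0.5 to get labels, then compare with the truth.
labels = (y_pred > 0.5).float()
accuracy = (labels == y_true).float().mean().item()
print(accuracy)  # 0.5
```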
For example, Linear(8, 6) means the input has 8 features and the output has 6. The dimension that matters here is the column dimension (the number of features, i.e. the number of unknowns). The row dimension does not matter, because each row is one sample: PyTorch layers accept a single sample or a whole batch of them and only care about the number of columns.
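A quick check of this with a standalone `Linear(8, 6)`: the weight is stored as (out_features, in_features), and the same layer handles 1 row or 100 rows, transforming each row independently as y = xWᵀ + b.

```python
import torch

torch.manual_seed(0)
layer = torch.nn.Linear(8, 6)  # in_features=8 columns in, out_features=6 columns out

print(layer.weight.shape)  # torch.Size([6, 8]), stored as (out, in)
print(layer.bias.shape)    # torch.Size([6])

one_sample = torch.randn(1, 8)      # 1 row
many_samples = torch.randn(100, 8)  # 100 rows

print(layer(one_sample).shape)    # torch.Size([1, 6])
print(layer(many_samples).shape)  # torch.Size([100, 6])

# Each row is transformed independently: y = x @ W^T + b
manual = many_samples @ layer.weight.T + layer.bias
print(torch.allclose(layer(many_samples), manual, atol=1e-6))  # True
```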
The model built by the final code is Linear(8, 6) → Linear(6, 4) → Linear(4, 2) → Linear(2, 1), with a sigmoid after each layer. (I added an extra layer in my code; you can also try more.)
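The same stack can also be written with `torch.nn.Sequential`. This is an alternative sketch, not the code used above; it traces the feature dimension through each `Linear` layer and counts the trainable parameters.

```python
import torch

# The same 8 -> 6 -> 4 -> 2 -> 1 stack as Model, written with nn.Sequential (alternative form).
seq_model = torch.nn.Sequential(
    torch.nn.Linear(8, 6), torch.nn.Sigmoid(),
    torch.nn.Linear(6, 4), torch.nn.Sigmoid(),
    torch.nn.Linear(4, 2), torch.nn.Sigmoid(),
    torch.nn.Linear(2, 1), torch.nn.Sigmoid(),
)

# Trace the shape through each layer with a batch of 5 random samples.
x = torch.randn(5, 8)
for layer in seq_model:
    x = layer(x)
    if isinstance(layer, torch.nn.Linear):
        print(layer, '->', tuple(x.shape))

# Parameters: (8*6+6) + (6*4+4) + (4*2+2) + (2*1+1) = 54 + 28 + 10 + 3 = 95
n_params = sum(p.numel() for p in seq_model.parameters())
print(n_params)  # 95
```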

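Note: although it is called a Linear layer, torch.nn.Linear by itself is only an affine map (y = xWᵀ + b). The non-linearity comes from the activation function applied after each layer (the sigmoid here). Without activations, a stack of Linear layers is still just one linear map, and it would not be a neural network in any useful sense.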
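To see why the activations are needed: without them, two stacked `Linear` layers equal a single one, since (xW1ᵀ + b1)W2ᵀ + b2 = x(W2W1)ᵀ + (W2b1 + b2). A sketch that verifies this numerically with random weights:

```python
import torch

torch.manual_seed(0)
lin1 = torch.nn.Linear(8, 6)
lin2 = torch.nn.Linear(6, 4)
x = torch.randn(10, 8)

# Two stacked Linear layers with NO activation in between...
stacked = lin2(lin1(x))

# ...equal ONE linear map with W = W2 @ W1 and b = W2 @ b1 + b2.
W = lin2.weight @ lin1.weight            # shape (4, 8)
b = lin2.weight @ lin1.bias + lin2.bias  # shape (4,)
collapsed = x @ W.T + b

print(torch.allclose(stacked, collapsed, atol=1e-5))  # True
```

Putting a sigmoid (or ReLU) between the two layers breaks this identity, which is what lets deeper stacks represent more than a single linear model.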

You can also try other activation functions, for example the ReLU already prepared in the commented-out lines of forward().
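A sketch of the ReLU variant, assembled from the commented-out lines in `forward()`; the class name `ReluModel` is hypothetical. The last layer keeps its sigmoid so that `BCELoss` still receives probabilities.

```python
import torch


class ReluModel(torch.nn.Module):
    """Hypothetical variant: ReLU on the hidden layers, sigmoid kept on the output."""

    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.activate = torch.nn.ReLU()
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        x = self.activate(self.linear1(x))
        x = self.activate(self.linear2(x))
        x = self.activate(self.linear3(x))
        # Output stays sigmoid so BCELoss receives values in (0, 1).
        return self.sigmoid(self.linear4(x))


model = ReluModel()
out = model(torch.randn(4, 8))
print(out.shape)  # torch.Size([4, 1])
```

Other activations, such as `torch.nn.Tanh`, can be swapped in for `self.activate` the same way. With only 2 units in the last hidden layer, a ReLU unit that outputs zero for every sample stops receiving gradient, which may be worth checking if the loss plateaus.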

Feel free to share your settings in the comments, along with how well your predictions turned out.
The data is not yet processed in mini-batches; I will add that later.
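Until then, here is a minimal sketch of mini-batch training with `TensorDataset` and `DataLoader`. Assumptions: random data replaces diabetes.csv, the batch size of 16 is arbitrary, and a one-layer stand-in model keeps the sketch short (the `Model` class from above should drop in unchanged).

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical random data standing in for diabetes.csv (the real file is not loaded here).
x_data = torch.randn(64, 8)
y_data = (x_data[:, :1] > 0).float()  # toy labels, shape (64, 1)

dataset = TensorDataset(x_data, y_data)
loader = DataLoader(dataset, batch_size=16, shuffle=True)  # reshuffled every epoch

# Stand-in model (one Linear + sigmoid); replace with Model() for the real network.
model = torch.nn.Sequential(torch.nn.Linear(8, 1), torch.nn.Sigmoid())
criterion = torch.nn.BCELoss(reduction='mean')
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

for epoch in range(20):
    for x_batch, y_batch in loader:  # 64 / 16 = 4 mini-batches per epoch
        y_pred = model(x_batch)
        loss = criterion(y_pred, y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

print(len(loader))  # 4
```

Each `optimizer.step()` now uses the gradient of one mini-batch instead of the whole dataset, so the weights are updated 4 times per epoch here rather than once.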
Reference material:
Dataset resources:
https://www.bilibili.com/video/BV1Y7411d7Ys?p=8&vd_source=eb730c561c03cdaf7ce5f40354ca252c