Building the simplest neural network with PyTorch
2022-07-03 05:49:00 【code bean】
The dataset contains medical indicators for a set of patients, together with whether each patient has diabetes.

diabetes.csv: training set
diabetes_test.csv: test set
```python
import numpy as np
import torch

# Every row: 8 feature columns followed by 1 label column (0 or 1)
xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])  # features, shape (N, 8)
# Note: the label must stay a two-dimensional matrix of shape (N, 1),
# so index with [-1] (a list) rather than -1
y_data = torch.from_numpy(xy[:, [-1]])
```
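A quick way to see why the label column is indexed with [-1] rather than -1: the sketch below uses a small made-up NumPy array in place of diabetes.csv (an assumption for illustration) and prints the resulting shapes. torch.nn.BCELoss expects the target to have the same shape as the model's (N, 1) output, so the 2-D version is the one that matches.

```python
import numpy as np

# Hypothetical stand-in for diabetes.csv: 4 samples, 8 feature columns + 1 label column
xy = np.arange(4 * 9, dtype=np.float32).reshape(4, 9)

features = xy[:, :-1]   # all columns except the last -> shape (4, 8)
label_1d = xy[:, -1]    # an integer index drops the axis -> shape (4,)
label_2d = xy[:, [-1]]  # a list index keeps the axis -> shape (4, 1)

print(features.shape, label_1d.shape, label_2d.shape)  # (4, 8) (4,) (4, 1)
```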
```python
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Reduce the feature dimension layer by layer: 8 -> 6 -> 4 -> 2 -> 1
        self.linear1 = torch.nn.Linear(8, 6)
        self.linear2 = torch.nn.Linear(6, 4)
        self.linear3 = torch.nn.Linear(4, 2)
        self.linear4 = torch.nn.Linear(2, 1)
        self.sigmoid = torch.nn.Sigmoid()
        self.activate = torch.nn.ReLU()

    # forward() is invoked by nn.Module.__call__(), i.e. whenever you write model(x)
    def forward(self, x):
        # Alternative: ReLU in the hidden layers
        # x = self.activate(self.linear1(x))
        # x = self.activate(self.linear2(x))
        # x = self.activate(self.linear3(x))
        x = self.sigmoid(self.linear1(x))
        x = self.sigmoid(self.linear2(x))
        x = self.sigmoid(self.linear3(x))
        x = self.sigmoid(self.linear4(x))  # output: a probability in (0, 1)
        return x


# model is callable because nn.Module implements __call__()
model = Model()
```
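How does model(x) end up in forward()? The toy class below is a hypothetical, heavily simplified stand-in for torch.nn.Module (TinyModule and Doubler are made-up names): its __call__ just hands the arguments to forward(). The real nn.Module.__call__ does more, for example running registered hooks, so treat this only as a mental model.

```python
class TinyModule:
    """Hypothetical, heavily simplified stand-in for torch.nn.Module (not the real implementation)."""

    def __call__(self, *args, **kwargs):
        # The real nn.Module.__call__ also runs registered hooks around this call
        return self.forward(*args, **kwargs)

    def forward(self, *args, **kwargs):
        raise NotImplementedError("subclasses must implement forward()")


class Doubler(TinyModule):
    def forward(self, x):
        return 2 * x


m = Doubler()
print(m(21))  # calling the object runs forward() -> 42
```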
```python
# Specify the loss function.
# (The old size_average=True/False argument is deprecated; use reduction='mean'/'sum'.)
# criterion = torch.nn.MSELoss(reduction='sum')  # 'sum': add up the losses; 'mean': average them
criterion = torch.nn.BCELoss(reduction='mean')  # binary cross-entropy loss
```
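Binary cross-entropy for a predicted probability p and a 0/1 label y is -(y·log p + (1 - y)·log(1 - p)), averaged over the samples when reduction='mean' and added up when reduction='sum'. A minimal NumPy sketch of that formula, including the clamping of log outputs at -100 that the PyTorch documentation describes for BCELoss (the function name bce is made up):

```python
import numpy as np


def bce(p, y, reduction="mean"):
    """Binary cross-entropy for probabilities p and 0/1 targets y of the same shape."""
    p = np.asarray(p, dtype=np.float64)
    y = np.asarray(y, dtype=np.float64)
    # Like BCELoss, clamp log outputs at -100 so p = 0 or p = 1 gives a finite loss
    with np.errstate(divide="ignore"):
        log_p = np.maximum(np.log(p), -100.0)
        log_1mp = np.maximum(np.log(1.0 - p), -100.0)
    losses = -(y * log_p + (1.0 - y) * log_1mp)
    if reduction == "mean":
        return losses.mean()
    if reduction == "sum":
        return losses.sum()
    return losses  # reduction == "none": one loss per element


print(bce([0.9, 0.1], [1, 0]))                   # mean loss
print(bce([0.9, 0.1], [1, 0], reduction="sum"))  # summed loss, twice the mean here
```

As a design note, PyTorch also provides torch.nn.BCEWithLogitsLoss, which takes raw scores and applies the sigmoid internally in a numerically more stable way; switching to it would mean removing the final sigmoid from forward().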
```python
# Specify the optimizer (the gradient-descent algorithm); it is bound to the model's parameters
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Adam was the most accurate in my tests; the others did much worse
# optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)
```
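To illustrate what optimizer.step() does with the gradients, here is a NumPy sketch of one plain gradient-descent (SGD) update and one Adam update for a single parameter array. The betas and eps defaults follow torch.optim.Adam, and lr=0.01 matches the article; the class and function names are hypothetical, and this simplified single-array version is not PyTorch's implementation.

```python
import numpy as np


def sgd_step(param, grad, lr=0.1):
    """Plain gradient descent: move each weight against its gradient."""
    return param - lr * grad


class AdamState:
    """Minimal Adam for one parameter array (betas/eps follow torch.optim.Adam defaults)."""

    def __init__(self, shape, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
        self.lr, self.b1, self.b2, self.eps = lr, betas[0], betas[1], eps
        self.m = np.zeros(shape)  # running average of gradients
        self.v = np.zeros(shape)  # running average of squared gradients
        self.t = 0                # step count, used for bias correction

    def step(self, param, grad):
        self.t += 1
        self.m = self.b1 * self.m + (1 - self.b1) * grad
        self.v = self.b2 * self.v + (1 - self.b2) * grad ** 2
        m_hat = self.m / (1 - self.b1 ** self.t)
        v_hat = self.v / (1 - self.b2 ** self.t)
        return param - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)


w = np.array([1.0, 1.0])
g = np.array([0.001, 100.0])    # two gradients of very different size
print(sgd_step(w, g, lr=0.01))  # second weight moves by 1.0, first barely moves
adam = AdamState(w.shape, lr=0.01)
print(adam.step(w, g))          # both weights move by about 0.01
```

Adam divides each step by a running estimate of the gradient's magnitude, so on the first step every weight moves by roughly lr regardless of how large its gradient is. That is one plausible reason it behaved better than SGD here, though this post does not test it.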
```python
for epoch in range(5000):
    y_pred = model(x_data)  # forward pass on the whole training set at once
    loss = criterion(y_pred, y_data)
    print(epoch, loss.item())

    # Clear the .grad of every w and b the optimizer manages. backward() ADDS to .grad,
    # so this must run before backward() (or after step()), never between backward()
    # and step(), otherwise the freshly computed gradients would be wiped before the update.
    optimizer.zero_grad()
    loss.backward()   # compute the gradient of the loss w.r.t. every w and b
    optimizer.step()  # update every w and b using those gradients
```
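The zero_grad / backward / step pattern can be mirrored by hand. The sketch below trains a one-layer logistic-regression model (not the 4-layer Model above) on synthetic data standing in for diabetes.csv (an assumption), and deliberately accumulates the gradient with += to imitate how backward() adds into .grad:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic stand-in for diabetes.csv (assumption): 200 samples, 8 features, labels (200, 1)
X = rng.normal(size=(200, 8))
true_w = rng.normal(size=(8, 1))
y = (X @ true_w > 0).astype(np.float64)

w = np.zeros((8, 1))
b = 0.0
grad_w = np.zeros_like(w)
lr = 0.5


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


losses = []
for epoch in range(500):
    p = sigmoid(X @ w + b)  # forward pass
    loss = -np.mean(y * np.log(p + 1e-7) + (1 - y) * np.log(1 - p + 1e-7))
    losses.append(loss)

    grad_w[:] = 0.0  # "zero_grad": reset the accumulated gradients
    grad_b = 0.0
    grad_w += X.T @ (p - y) / len(X)  # "backward": ADD this pass's gradient
    grad_b += float(np.mean(p - y))
    w -= lr * grad_w  # "step": update the parameters
    b -= lr * grad_b

accuracy = np.mean((sigmoid(X @ w + b) > 0.5) == (y > 0.5))
print(f"loss {losses[0]:.3f} -> {losses[-1]:.3f}, train accuracy {accuracy:.2f}")
```

If the reset line were dropped, each step would use the sum of all past gradients, which is exactly the bug optimizer.zero_grad() prevents.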
```python
# Test
xy = np.loadtxt('diabetes_test.csv', delimiter=',', dtype=np.float32)
x_test = torch.from_numpy(xy[:, :-1])
y_true = torch.from_numpy(xy[:, -1])  # 1-D labels here, shape (M,)

with torch.no_grad():      # no gradients are needed for prediction
    y_test = model(x_test)  # predicted probabilities, shape (M, 1)
print('y_pred =', y_test)

# Compare the predicted labels (threshold 0.5) with the true labels
for index, p in enumerate(y_test.numpy()):
    predicted = 1 if p[0] > 0.5 else 0
    print(predicted, int(y_true[index].item()))
```
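Instead of printing one pair per sample, the comparison can be summarised as an accuracy. A minimal NumPy sketch with hypothetical helper names and made-up probabilities (not results from diabetes_test.csv); with real data you would pass the model's test-set probabilities and the test labels as NumPy arrays:

```python
import numpy as np


def predict_labels(probs, threshold=0.5):
    """Turn an (M, 1) or (M,) array of probabilities into 0/1 labels."""
    return (np.asarray(probs).reshape(-1) > threshold).astype(int)


def accuracy(probs, labels, threshold=0.5):
    """Fraction of samples whose thresholded prediction equals the true label."""
    preds = predict_labels(probs, threshold)
    return float(np.mean(preds == np.asarray(labels).reshape(-1).astype(int)))


# Made-up example values
probs = np.array([[0.9], [0.2], [0.6], [0.4]])
labels = np.array([1.0, 0.0, 0.0, 0.0])
print(predict_labels(probs))    # [1 0 1 0]
print(accuracy(probs, labels))  # 0.75
```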
As the code shows, each sample starts with 8 features, and the layers reduce the dimension step by step until a single value comes out. Although the input is multi-dimensional, the output is a single probability, so this is still a binary classification problem.
For example, Linear(8, 6) means 8 input features and 6 output features. The dimension that matters here is the column dimension (the number of features, i.e. unknowns), not the row dimension: rows are samples, and a PyTorch layer does not care how many samples it receives, one or a whole batch. It only constrains the number of columns.
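The point about rows and columns can be checked with a NumPy re-implementation of the same 8 → 6 → 4 → 2 → 1 stack. The weights are random (an assumption; nothing is trained here), and the weight layout follows torch.nn.Linear, which stores weight as (out_features, in_features) and computes x·Wᵀ + b. Only the column count is fixed; the number of rows passes straight through:

```python
import numpy as np

rng = np.random.default_rng(0)
layer_sizes = [8, 6, 4, 2, 1]  # same widths as the Model class above

# Like torch.nn.Linear(in, out): weight has shape (out, in), bias has shape (out,)
weights = [rng.normal(size=(n_out, n_in)) for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])]
biases = [np.zeros(n_out) for n_out in layer_sizes[1:]]


def linear(x, W, b):
    # y = x W^T + b works for any number of rows (samples) in x
    return x @ W.T + b


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def forward(x):
    shapes = [x.shape]
    for W, b in zip(weights, biases):
        x = sigmoid(linear(x, W, b))
        shapes.append(x.shape)
    return x, shapes


for n_samples in (1, 5, 100):
    out, shapes = forward(rng.normal(size=(n_samples, 8)))
    print(shapes)  # e.g. for 5 samples: [(5, 8), (5, 6), (5, 4), (5, 2), (5, 1)]
```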
The model built by the final code is 8 → 6 → 4 → 2 → 1, with a sigmoid after every linear layer. (My code has one more layer than the model in the original figure; feel free to try more.)

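Note: although these layers are called Linear, the model applies a sigmoid after each one, so the network as a whole is nonlinear. torch.nn.Linear by itself only computes x·Wᵀ + b and contains no activation function; without the activations in between, the stacked layers would collapse into a single linear model rather than a real neural network.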
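A small NumPy check of why the activations matter: with random weights (made up for illustration), two linear layers with nothing between them give exactly the same output as one linear layer with W = W2·W1 and b = W2·b1 + b2.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.normal(size=(5, 8))
W1, b1 = rng.normal(size=(6, 8)), rng.normal(size=6)  # like Linear(8, 6)
W2, b2 = rng.normal(size=(4, 6)), rng.normal(size=4)  # like Linear(6, 4)

# Two linear layers with no activation in between ...
two_layers = (x @ W1.T + b1) @ W2.T + b2

# ... equal ONE linear layer with these combined parameters
W = W2 @ W1
b = W2 @ b1 + b2
one_layer = x @ W.T + b

print(np.allclose(two_layers, one_layer))  # True
```

Putting a sigmoid (or ReLU) between the two layers breaks this identity, which is what lets extra depth add expressive power.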

You can also try other activation functions in the hidden layers, such as the ReLU variant left commented out in forward().
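For reference, NumPy sketches of three common activations; PyTorch provides them as torch.nn.Sigmoid, torch.nn.Tanh and torch.nn.ReLU. If you swap the hidden-layer activation, keep the sigmoid on the last layer: BCELoss expects its input to be a probability between 0 and 1.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))  # squashes to (0, 1)


def tanh(z):
    return np.tanh(z)  # squashes to (-1, 1), centred on zero


def relu(z):
    return np.maximum(z, 0.0)  # 0 for negative inputs, identity for positive ones


z = np.array([-2.0, 0.0, 2.0])
for name, f in [("sigmoid", sigmoid), ("tanh", tanh), ("relu", relu)]:
    print(name, np.round(f(z), 3))
```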

Feel free to post your settings in the comments and share how well your predictions turn out.
This version does not yet split the data into mini-batches; I will add that later.
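In PyTorch, mini-batching is usually done with torch.utils.data.TensorDataset and DataLoader(dataset, batch_size=..., shuffle=True), running zero_grad / backward / step once per batch. The NumPy generator below (iterate_minibatches is a hypothetical name, run on made-up data) sketches the same idea: shuffle once per epoch, then slice the shuffled indices into batches.

```python
import numpy as np


def iterate_minibatches(x, y, batch_size, rng):
    """Yield shuffled (x_batch, y_batch) pairs covering every sample once per epoch."""
    order = rng.permutation(len(x))
    for start in range(0, len(x), batch_size):
        idx = order[start:start + batch_size]
        yield x[idx], y[idx]


rng = np.random.default_rng(0)
x = np.arange(10 * 8, dtype=np.float32).reshape(10, 8)  # 10 made-up samples
y = np.arange(10, dtype=np.float32).reshape(10, 1)

for epoch in range(2):
    sizes = [len(xb) for xb, yb in iterate_minibatches(x, y, batch_size=4, rng=rng)]
    print(epoch, sizes)  # the last batch is smaller: [4, 4, 2]
```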
Reference material:
Dataset resources:
https://www.bilibili.com/video/BV1Y7411d7Ys?p=8&vd_source=eb730c561c03cdaf7ce5f40354ca252c