当前位置:网站首页>pytorch 搭建神经网络最简版
pytorch 搭建神经网络最简版
2022-07-03 05:45:00 【code bean】
数据集为,糖尿病患者各项指标以及是否换糖尿病。

diabetes.csv 训练集
diabetes_test.csv 测试集
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import numpy as np
# 注意这里必须写成两维的矩阵
xy = np.loadtxt('diabetes.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])
y_data = torch.from_numpy(xy[:, [-1]])
class Model(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 2)
self.linear4 = torch.nn.Linear(2, 1)
self.sigmoid = torch.nn.Sigmoid()
self.activate = torch.nn.ReLU()
# __call__() 中会调用这个函数!
def forward(self, x):
# x = self.activate(self.linear1(x))
# x = self.activate(self.linear2(x))
# x = self.activate(self.linear3(x))
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
x = self.sigmoid(self.linear4(x))
return x
# model为可调用的! 实现了 __call__()
model = Model()
# 指定损失函数
# criterion = torch.nn.MSELoss(size_average=Flase) # True
# criterion = torch.nn.MSELoss(reduction='sum') # sum:求和 mean:求平均
criterion = torch.nn.BCELoss(size_average=True) # 二分类交叉熵损失函数
# -- 指定优化器(其实就是有关梯度下降的算法,负责),这里将优化器和model进行了关联
# optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 这个用这个很准啊,其他得根本不行啊
# optimizer = torch.optim.Rprop(model.parameters(), lr=0.01)
for epoch in range(5000):
y_pred = model(x_data) # 直接把整个测试数据都放入了
loss = criterion(y_pred, y_data)
print(epoch, loss.item())
optimizer.zero_grad() # 会自动找到所有的w和b进行清零!优化器的作用 (为啥这个放到loss.backward()后面清零就不行了呢?)
loss.backward()
optimizer.step() # 会自动找到所有的w和b进行更新,优化器的作用!
# 测试
xy = np.loadtxt('diabetes_test.csv', delimiter=',', dtype=np.float32)
x_data = torch.from_numpy(xy[:, :-1])
y_data = torch.from_numpy(xy[:, -1])
x_test = torch.Tensor(x_data)
y_test = model(x_test) # 预测
print('y_pred = ', y_test.data)
# 对比预测结果和真实结果
for index, i in enumerate(y_test.data.numpy()):
if i[0] > 0.5:
print(1, int(y_data[index].item()))
else:
print(0, int(y_data[index].item()))
从代码看出,从一开始数据特征维度为8维,通过一层层的降维最终输出一维数据(输入虽然是多维,但是输出是一维一个概率值,仍然是个二分类问题)
如: Linear(8, 6) 表示输入特征是8维,输出特征6维。此时考虑的维度是列(未知数个数)。不考虑行的维度的,因为行的维度表示的是样本的个数,而对于pythroch的函数,样本个数是无所谓的,一个也行,一堆也罢,它只关心列的维度。
最终代码构建的模型如下:(我代码里是多构建了一层,大家也可以多尝试)

这里注意,虽然这里叫Linear Layer,但是其实,它里面是套了一个激活函数的,所以其实是非线性的,如果是线性的也就不叫神经网络

其他的激活函数也可以多试试:

欢迎大家将自己的设置发到评论区。看看大家预测的效果如何。
目前还没有对数据进行小批量处理,后续加上。
参考资料:
数据集资源:
边栏推荐
- Apt update and apt upgrade commands - what is the difference?
- @Import annotation: four ways to import configuration classes & source code analysis
- Complete set of C language file operation functions (super detailed)
- Capacity expansion mechanism of map
- How to use source insight
- [branch and cycle] | | super long detailed explanation + code analysis + a trick game
- Using the ethtool command by example
- [set theory] relational closure (reflexive closure | symmetric closure | transitive closure)
- chromedriver对应版本下载
- NG Textarea-auto-resize
猜你喜欢

Understand one-way hash function
![[teacher Zhao Yuqiang] redis's slow query log](/img/a7/2140744ebad9f1dc0a609254cc618e.jpg)
[teacher Zhao Yuqiang] redis's slow query log
![[together Shangshui Shuo series] day 7 content +day8](/img/fc/74b12addde3a4d3480e98f8578a969.png)
[together Shangshui Shuo series] day 7 content +day8

Brief introduction of realsense d435i imaging principle
![[teacher Zhao Yuqiang] index in mongodb (Part 2)](/img/a7/2140744ebad9f1dc0a609254cc618e.jpg)
[teacher Zhao Yuqiang] index in mongodb (Part 2)
![[trivia of two-dimensional array application] | [simple version] [detailed steps + code]](/img/84/98c1220d0f7bc3a948125ead6ff3d9.jpg)
[trivia of two-dimensional array application] | [simple version] [detailed steps + code]

Qt读写Excel--QXlsx插入图表5

Life is a process of continuous learning

Exception when introducing redistemplate: noclassdeffounderror: com/fasterxml/jackson/core/jsonprocessingexception

Training method of grasping angle in grasping detection
随机推荐
Notepad++ wrap by specified character
How to set up altaro offsite server for replication
Pytorch through load_ state_ Dict load weight
[teacher Zhao Yuqiang] index in mongodb (Part 2)
Final review (Day2)
牛客网 JS 分隔符
Redhat7 system root user password cracking
2022.7.2 simulation match
NG Textarea-auto-resize
[untitled]
chromedriver对应版本下载
今天很多 CTO 都是被干掉的,因为他没有成就业务
Jetson AgX Orin platform porting ar0233 gw5200 max9295 camera driver
ES 2022 正式发布!有哪些新特性?
Exception when introducing redistemplate: noclassdeffounderror: com/fasterxml/jackson/core/jsonprocessingexception
【一起上水硕系列】Day 10
Best practices for setting up altaro VM backups
How do I migrate my altaro VM backup configuration to another machine?
Gan network thought
Personal outlook | looking forward to the future from Xiaobai's self analysis and future planning
https://www.bilibili.com/video/BV1Y7411d7Ys?p=8&vd_source=eb730c561c03cdaf7ce5f40354ca252c