当前位置:网站首页>【深度学习】:《PyTorch入门到项目实战》第四天:从0到1实现logistic回归(附源码)
【深度学习】:《PyTorch入门到项目实战》第四天:从0到1实现logistic回归(附源码)
2022-07-28 16:02:00 【JoJo的数据分析历险记】
【深度学习】:《PyTorch入门到项目实战》从0到1实现logistic回归
- 本文收录于【深度学习】:《PyTorch入门到项目实战》专栏,此专栏主要记录如何使用
PyTorch实现深度学习笔记,尽量坚持每周持续更新,欢迎大家订阅! - 个人主页:JoJo的数据分析历险记
- 个人介绍:小编大四统计在读,目前保研到统计学top3高校继续攻读统计研究生
- 如果文章对你有帮助,欢迎
关注、点赞、收藏、订阅专栏
参考资料:本专栏主要以沐神《动手学深度学习》为学习资料,记录自己的学习笔记,能力有限,如有错误,欢迎大家指正。同时沐神上传了的教学视频和教材,大家可以前往学习。

文章目录
logistic回归虽然名字是回归,但实际上是一个分类算法,主要处理二分类问题,具体理论部分大家可以看我的这篇文章。机器学习算法:分类算法详解
基本模型如下:
H ( x ) = 1 1 + e − W T X H(x) = \frac{1}{1+e^{-W^TX}} H(x)=1+e−WTX1
损失函数:
c o s t ( W ) = − 1 m ∑ y l o g ( H ( x ) ) + ( 1 − y ) ( l o g ( 1 − H ( x ) ) cost(W) = -\frac{1}{m}\sum ylog(H(x))+(1-y)(log(1-H(x)) cost(W)=−m1∑ylog(H(x))+(1−y)(log(1−H(x))
其中y=1或0.从损失函数可以看出,如果 y y y和 H ( x ) H(x) H(x)很接近,则损失函数越小,下面我们来看看如何使用pytorch实现logistic回归
1.创建数据集
import torch
import torch.nn as nn
import torch.nn.functional as F#神经网络内置函数
import torch.optim as optim
# 设计随机数,为了结果的可复现性
torch.manual_seed(1)
<torch._C.Generator at 0x23225f75f78>
x = torch.tensor([[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]],dtype=torch.float)
y = torch.tensor([[0], [0], [0], [1], [1], [1]],dtype=torch.float)
考虑以下分类问题:给定每个学生观看讲座和在代码实验室工作的小时数,预测学生是否通过了课程。例如,第一个(索引0)学生看了一个小时的讲座,在实验课上花了两个小时([1,2]),结果课程不及格([0])。
2.初始化参数
和之前一样,我们初始化两个参数
W = torch.zeros((2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
3.计算模型
h = 1 / (1 + torch.exp(-(torch.matmul(x,W) + b)))
tensor([[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000],
[0.5000]], grad_fn=<SigmoidBackward0>)
在torch中,我们也可以使用torch.sigmoid()函数得到一样的结果
h = torch.sigmoid(torch.matmul(x,W)+b)
4.定义损失函数
def loss_fun(y,h):
return (-(y * torch.log(h) +
(1 - y) * torch.log(1 - h))).mean()
在nn中,包含许多内置函数,其中包含了计算交叉熵函数F.binary_cross_entropy,可以实现与上述代码一样的结果
F.binary_cross_entropy(h, y)
tensor(0.6931, grad_fn=<BinaryCrossEntropyBackward0>)
5.梯度下降求解
optim中包含了常见的优化算法,包括Adam,SGD等,这里我们还是和之前一样使用随机梯度下降,后续会介绍其他的优化算法
optimizer = optim.SGD([W, b], lr=0.05)
6.模型训练完整代码
''' 生成数据集 '''
x = torch.tensor([[1, 2], [2, 3], [3, 1], [4, 3], [5, 3], [6, 2]],dtype=torch.float)
y = torch.tensor([[0], [0], [0], [1], [1], [1]],dtype=torch.float)
''' 初始化参数 '''
W = torch.zeros((2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
''' 训练模型 '''
optimizer = optim.SGD([W, b], lr=0.5)
nb_epochs = 1000
for epoch in range(nb_epochs + 1):
#计算h
h = torch.sigmoid(x.matmul(W) + b)
#计算损失函数
cost = -(y * torch.log(h) +
(1 - y) * torch.log(1 - h)).mean()
# 梯度下降优化
optimizer.zero_grad()
cost.backward()
optimizer.step()
if epoch % 100 == 0:
print('Epoch {:4d}/{} Cost: {:.6f}'.format(
epoch, nb_epochs, cost.item()
))
Epoch 0/1000 Cost: 0.693147
Epoch 100/1000 Cost: 0.232941
Epoch 200/1000 Cost: 0.147042
Epoch 300/1000 Cost: 0.107431
Epoch 400/1000 Cost: 0.084848
Epoch 500/1000 Cost: 0.070247
Epoch 600/1000 Cost: 0.060012
Epoch 700/1000 Cost: 0.052428
Epoch 800/1000 Cost: 0.046575
Epoch 900/1000 Cost: 0.041916
Epoch 1000/1000 Cost: 0.038117
7.评估模型
在我们完成模型的训练后,我们想检查我们的模型是否适合训练集。
# 首先根据估计的参数结果计算h
h = torch.sigmoid(x.matmul(W) + b)
print(h)
tensor([[0.0033],
[0.0791],
[0.1106],
[0.8929],
[0.9880],
[0.9968]], grad_fn=<SigmoidBackward0>)
# 大于0.5的为True,小于0.5的为False
prediction = h >= torch.FloatTensor([0.5])
print(prediction)
tensor([[False],
[False],
[False],
[ True],
[ True],
[ True]])
# 注意在python中0=False,1=True
print(prediction)
print(y)
tensor([[False],
[False],
[False],
[ True],
[ True],
[ True]])
tensor([[0.],
[0.],
[0.],
[1.],
[1.],
[1.]])
# 计算预测值和真实值相同的个数
correct_prediction = prediction.float() == y
print(correct_prediction)
tensor([[True],
[True],
[True],
[True],
[True],
[True]])
# 计算预测正确的数量占总数量的比例
accuracy = correct_prediction.sum().item() / len(correct_prediction)
print('The model has an accuracy of {:2.2f}% for the training set.'.format(accuracy * 100))
The model has an accuracy of 100.00% for the training set.
🥤8.使用nn.Module实现logistic回归
上面为了演示logistic回归的具体实现原理,我们是使用一步一步实现的,但是在实际中,往往会使用nn.module或者nn实现,下面是实现logistic的简化代码。
''' 定义二元分类器 '''
class BinaryClassifier(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(2, 1)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
return self.sigmoid(self.linear(x))
model = BinaryClassifier()
''' 定义随机梯度下降 '''
optimizer = optim.SGD(model.parameters(), lr=0.7)
''' 模型训练 '''
nb_epochs = 100
for epoch in range(nb_epochs + 1):
#计算h
hypothesis = model(x)
# 计算损失函数
cost = F.binary_cross_entropy(hypothesis, y)
# 梯度下降
optimizer.zero_grad()
cost.backward()
optimizer.step()
# 输出结果
if epoch % 10 == 0:
prediction = hypothesis >= torch.FloatTensor([0.5])
correct_prediction = prediction.float() == y
accuracy = correct_prediction.sum().item() / len(correct_prediction)
print('Epoch {:4d}/{} Cost: {:.6f} Accuracy {:2.2f}%'.format(
epoch, nb_epochs, cost.item(), accuracy * 100,
))
Epoch 0/100 Cost: 0.734527 Accuracy 50.00%
Epoch 10/100 Cost: 0.446570 Accuracy 66.67%
Epoch 20/100 Cost: 0.448868 Accuracy 66.67%
Epoch 30/100 Cost: 0.375859 Accuracy 83.33%
Epoch 40/100 Cost: 0.318583 Accuracy 83.33%
Epoch 50/100 Cost: 0.268096 Accuracy 83.33%
Epoch 60/100 Cost: 0.222295 Accuracy 100.00%
Epoch 70/100 Cost: 0.183465 Accuracy 100.00%
Epoch 80/100 Cost: 0.158036 Accuracy 100.00%
Epoch 90/100 Cost: 0.144541 Accuracy 100.00%
Epoch 100/100 Cost: 0.134652 Accuracy 100.00%
本章的介绍到此介绍,如果文章对你有帮助,请多多点赞、收藏、评论、关注支持!!
边栏推荐
- Analysis of echo service model in the first six chapters of unp
- LeetCode-学会对无序链表进行插入排序(详解)
- IM即时通讯开发优化提升连接成功率、速度等
- HM二次开发 - Data Names及其使用
- USB产品(FX3、CCG3PA)的调试方法
- 排序2-冒泡排序与快速排序(递归加非递归讲解)
- egg(十九):使用egg-redis性能优化,缓存数据提升响应效率
- HM secondary development - data names and its use
- 有趣的 Kotlin 0x09:Extensions are resolved statically
- Multiple commands produce ‘.../xxx.app/Assets.car‘问题
猜你喜欢

Learn to use MySQL explain to execute the plan, and SQL performance tuning is no longer difficult

Several methods of HyperMesh running script files

遭MQ连连干翻后的醒悟!含恨码出这份MQ手册助力秋招之旅

排序5-计数排序

有趣的 Kotlin 0x08:What am I

大学生参加六星教育PHP培训,找到了薪水远超同龄人的工作

排序3-选择排序与归并排序(递归实现+非递归实现)

Microsoft: edge browser has built-in disk cache compression technology, which can save space and not reduce system performance

在vs code上配置Hypermesh二次开发环境

nowcode-学会删除链表中重复元素两题(详解)
随机推荐
【JS】1394- ES2022 的 8 个实用的新功能
MD5加密验证
Splash (rendering JS service) introduction installation
Microsoft: edge browser has built-in disk cache compression technology, which can save space and not reduce system performance
epoll水平出发何边沿触发
IM即时通讯软件开发网络请求成功率的优化
Ansa secondary development - Introduction to interface development tools
Debugging methods of USB products (fx3, ccg3pa)
有趣的 Kotlin 0x06:List minus list
Wake up after being repeatedly upset by MQ! Hate code out this MQ manual to help the journey of autumn recruitment
Simple addition, deletion, modification and query of commodity information
LwIP develops | socket | TCP | keepalive heartbeat mechanism
“蔚来杯“2022牛客暑期多校训练营3 J.Journey 0-1最短路
Ansa secondary development - apps and ansa plug-in management
Applet: get element node information
智慧园区是未来发展的趋势吗?
Hdu1847 problem solving ideas
关于Bug处理的一些看法
ticdc同步数据怎么设置只同步指定的库?
Oracle system composition