当前位置:网站首页>基于Pytorch 框架手动完成线性回归
基于Pytorch 框架手动完成线性回归
2022-07-07 04:57:00 【不想秃头的学生】
Pytorch完成线性回归
hello,各位朋友好久不见,最近在忙着期末考试,现在结束之后继续更新咱们的Pytorch框架学习笔记
目标
- 知道
requires_grad
的作用 - 知道如何使用
backward
- 知道如何手动完成线性回归
1. 向前计算
对于pytorch中的一个tensor,如果设置它的属性 .requires_grad
为True
,那么它将会追踪对于该张量的所有操作。或者可以理解为,这个tensor是一个参数,后续会被计算梯度,更新该参数。
1.1 计算过程
假设有以下条件(1/4表示求均值,xi中有4个数),使用torch完成其向前计算的过程
KaTeX parse error: No such environment: align* at position 8: \begin{̲a̲l̲i̲g̲n̲*̲}̲ &o = \frac{1}{…
如果x为参数,需要对其进行梯度的计算和更新
那么,在最开始随机设置x的值的过程中,需要设置他的requires_grad属性为True,其默认值为False
import torch
x = torch.ones(2, 2, requires_grad=True) #初始化参数x并设置requires_grad=True用来追踪其计算历史
print(x)
#tensor([[1., 1.],
# [1., 1.]], requires_grad=True)
y = x+2
print(y)
#tensor([[3., 3.],
# [3., 3.]], grad_fn=<AddBackward0>)
z = y*y*3 #平方x3
print(x)
#tensor([[27., 27.],
# [27., 27.]], grad_fn=<MulBackward0>)
out = z.mean() #求均值
print(out)
#tensor(27., grad_fn=<MeanBackward0>)
从上述代码可以看出:
- x的requires_grad属性为True
- 之后的每次计算都会修改其
grad_fn
属性,用来记录做过的操作- 通过这个函数和grad_fn能够组成一个和前一小节类似的计算图
1.2 requires_grad和grad_fn
a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad) #False
a.requires_grad_(True) #就地修改
print(a.requires_grad) #True
b = (a * a).sum()
print(b.grad_fn) # <SumBackward0 object at 0x4e2b14345d21>
with torch.no_gard():
c = (a * a).sum() #tensor(151.6830),此时c没有gard_fn
print(c.requires_grad) #False
注意:
为了防止跟踪历史记录(和使用内存),可以将代码块包装在with torch.no_grad():
中。在评估模型时特别有用,因为模型可能具有requires_grad = True
的可训练的参数,但是我们不需要在此过程中对他们进行梯度计算。
2. 梯度计算
对于1.1 中的out而言,我们可以使用backward
方法来进行反向传播,计算梯度
out.backward()
,此时便能够求出导数 d o u t d x \frac{d out}{dx} dxdout,调用x.gard
能够获取导数值
得到
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
因为:
d ( O ) d ( x i ) = 3 2 ( x i + 2 ) \frac{d(O)}{d(x_i)} = \frac{3}{2}(x_i+2) d(xi)d(O)=23(xi+2)
在 x i x_i xi等于1时其值为4.5
注意:在输出为一个标量的情况下,我们可以调用输出tensor
的backword()
方法,但是在数据是一个向量的时候,调用backward()
的时候还需要传入其他参数。
很多时候我们的损失函数都是一个标量,所以这里就不再介绍损失为向量的情况。
loss.backward()
就是根据损失函数,对参数(requires_grad=True)的去计算他的梯度,并且把它累加保存到x.gard
,此时还并未更新其梯度
注意点:
tensor.data
:在tensor的require_grad=False,tensor.data和tensor等价
require_grad=True时,tensor.data仅仅是获取tensor中的数据
tensor.numpy()
:require_grad=True
不能够直接转换,需要使用tensor.detach().numpy()
3. 线性回归实现
下面,我们使用一个自定义的数据,来使用torch实现一个简单的线性回归
假设我们的基础模型就是y = wx+b
,其中w和b均为参数,我们使用y = 3x+0.8
来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8
- 准备数据
- 计算预测值
- 计算损失,把参数的梯度置为0,进行反向传播
- 更新参数
import torch
import numpy as np
from matplotlib import pyplot as plt
#1. 准备数据 y = 3x+0.8,准备参数
x = torch.rand([50])
y = 3*x + 0.8
w = torch.rand(1,requires_grad=True)
b = torch.rand(1,requires_grad=True)
def loss_fn(y,y_predict):
loss = (y_predict-y).pow(2).mean()
for i in [w,b]:
#每次反向传播前把梯度置为0
if i.grad is not None:
i.grad.data.zero_()
# [i.grad.data.zero_() for i in [w,b] if i.grad is not None]
loss.backward()
return loss.data
def optimize(learning_rate):
# print(w.grad.data,w.data,b.data)
w.data -= learning_rate* w.grad.data
b.data -= learning_rate* b.grad.data
for i in range(3000):
#2. 计算预测值
y_predict = x*w + b
#3.计算损失,把参数的梯度置为0,进行反向传播
loss = loss_fn(y,y_predict)
if i%500 == 0:
print(i,loss)
#4. 更新参数w和b
optimize(0.01)
# 绘制图形,观察训练结束的预测值和真实值
predict = x*w + b #使用训练后的w和b计算预测值
plt.scatter(x.data.numpy(), y.data.numpy(),c = "r")
plt.plot(x.data.numpy(), predict.data.numpy())
plt.show()
print("w",w)
print("b",b)
图形效果如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EPoO4Eha-1656763156468)(…/images/1.2/线性回归1.png)]
打印w和b,可有
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
图形效果如下:
[外链图片转存中...(img-EPoO4Eha-1656763156468)]
打印w和b,可有
```python
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
边栏推荐
猜你喜欢
[2022 actf] Web Topic recurrence
2022年茶艺师(中级)考试试题及模拟考试
自定义类加载器加载网络Class
JSON data flattening pd json_ normalize
QT learning 26 integrated example of layout management
Figure out the working principle of gpt3
Thinkcmf6.0 installation tutorial
互动送书-《Oracle DBA工作笔记》签名版
These five fishing artifacts are too hot! Programmer: I know, delete it quickly!
Cnopendata list data of Chinese colleges and Universities
随机推荐
【数字IC验证快速入门】17、SystemVerilog学习之基本语法4(随机化Randomization)
Leanote private cloud note building
【踩坑系列】uniapp之h5 跨域的问题
LeetCode简单题之字符串中最大的 3 位相同数字
Quickly use Jacobo code coverage statistics
有 Docker 谁还在自己本地安装 Mysql ?
2022年全国最新消防设施操作员(初级消防设施操作员)模拟题及答案
【数字IC验证快速入门】10、Verilog RTL设计必会的FIFO
Bugku CTF daily one question chessboard with only black chess
LeetCode中等题之我的日程安排表 I
Linux server development, redis protocol and asynchronous mode
C语言通信行程卡后台系统
Linux server development, redis source code storage principle and data model
Es FAQ summary
【数字IC验证快速入门】15、SystemVerilog学习之基本语法2(操作符、类型转换、循环、Task/Function...内含实践练习)
【数字IC验证快速入门】12、SystemVerilog TestBench(SVTB)入门
Merging binary trees by recursion
央视太暖心了,手把手教你写HR最喜欢的简历
Few shot Learning & meta learning: small sample learning principle and Siamese network structure (I)
力扣(LeetCode)187. 重复的DNA序列(2022.07.06)