当前位置:网站首页>基于Pytorch 框架手动完成线性回归
基于Pytorch 框架手动完成线性回归
2022-07-07 04:57:00 【不想秃头的学生】
Pytorch完成线性回归
hello,各位朋友好久不见,最近在忙着期末考试,现在结束之后继续更新咱们的Pytorch框架学习笔记
目标
- 知道
requires_grad
的作用 - 知道如何使用
backward
- 知道如何手动完成线性回归
1. 向前计算
对于pytorch中的一个tensor,如果设置它的属性 .requires_grad
为True
,那么它将会追踪对于该张量的所有操作。或者可以理解为,这个tensor是一个参数,后续会被计算梯度,更新该参数。
1.1 计算过程
假设有以下条件(1/4表示求均值,xi中有4个数),使用torch完成其向前计算的过程
KaTeX parse error: No such environment: align* at position 8: \begin{̲a̲l̲i̲g̲n̲*̲}̲ &o = \frac{1}{…
如果x为参数,需要对其进行梯度的计算和更新
那么,在最开始随机设置x的值的过程中,需要设置他的requires_grad属性为True,其默认值为False
import torch
x = torch.ones(2, 2, requires_grad=True) #初始化参数x并设置requires_grad=True用来追踪其计算历史
print(x)
#tensor([[1., 1.],
# [1., 1.]], requires_grad=True)
y = x+2
print(y)
#tensor([[3., 3.],
# [3., 3.]], grad_fn=<AddBackward0>)
z = y*y*3 #平方x3
print(x)
#tensor([[27., 27.],
# [27., 27.]], grad_fn=<MulBackward0>)
out = z.mean() #求均值
print(out)
#tensor(27., grad_fn=<MeanBackward0>)
从上述代码可以看出:
- x的requires_grad属性为True
- 之后的每次计算都会修改其
grad_fn
属性,用来记录做过的操作- 通过这个函数和grad_fn能够组成一个和前一小节类似的计算图
1.2 requires_grad和grad_fn
a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad) #False
a.requires_grad_(True) #就地修改
print(a.requires_grad) #True
b = (a * a).sum()
print(b.grad_fn) # <SumBackward0 object at 0x4e2b14345d21>
with torch.no_gard():
c = (a * a).sum() #tensor(151.6830),此时c没有gard_fn
print(c.requires_grad) #False
注意:
为了防止跟踪历史记录(和使用内存),可以将代码块包装在with torch.no_grad():
中。在评估模型时特别有用,因为模型可能具有requires_grad = True
的可训练的参数,但是我们不需要在此过程中对他们进行梯度计算。
2. 梯度计算
对于1.1 中的out而言,我们可以使用backward
方法来进行反向传播,计算梯度
out.backward()
,此时便能够求出导数 d o u t d x \frac{d out}{dx} dxdout,调用x.gard
能够获取导数值
得到
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
因为:
d ( O ) d ( x i ) = 3 2 ( x i + 2 ) \frac{d(O)}{d(x_i)} = \frac{3}{2}(x_i+2) d(xi)d(O)=23(xi+2)
在 x i x_i xi等于1时其值为4.5
注意:在输出为一个标量的情况下,我们可以调用输出tensor
的backword()
方法,但是在数据是一个向量的时候,调用backward()
的时候还需要传入其他参数。
很多时候我们的损失函数都是一个标量,所以这里就不再介绍损失为向量的情况。
loss.backward()
就是根据损失函数,对参数(requires_grad=True)的去计算他的梯度,并且把它累加保存到x.gard
,此时还并未更新其梯度
注意点:
tensor.data
:在tensor的require_grad=False,tensor.data和tensor等价
require_grad=True时,tensor.data仅仅是获取tensor中的数据
tensor.numpy()
:require_grad=True
不能够直接转换,需要使用tensor.detach().numpy()
3. 线性回归实现
下面,我们使用一个自定义的数据,来使用torch实现一个简单的线性回归
假设我们的基础模型就是y = wx+b
,其中w和b均为参数,我们使用y = 3x+0.8
来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8
- 准备数据
- 计算预测值
- 计算损失,把参数的梯度置为0,进行反向传播
- 更新参数
import torch
import numpy as np
from matplotlib import pyplot as plt
#1. 准备数据 y = 3x+0.8,准备参数
x = torch.rand([50])
y = 3*x + 0.8
w = torch.rand(1,requires_grad=True)
b = torch.rand(1,requires_grad=True)
def loss_fn(y,y_predict):
loss = (y_predict-y).pow(2).mean()
for i in [w,b]:
#每次反向传播前把梯度置为0
if i.grad is not None:
i.grad.data.zero_()
# [i.grad.data.zero_() for i in [w,b] if i.grad is not None]
loss.backward()
return loss.data
def optimize(learning_rate):
# print(w.grad.data,w.data,b.data)
w.data -= learning_rate* w.grad.data
b.data -= learning_rate* b.grad.data
for i in range(3000):
#2. 计算预测值
y_predict = x*w + b
#3.计算损失,把参数的梯度置为0,进行反向传播
loss = loss_fn(y,y_predict)
if i%500 == 0:
print(i,loss)
#4. 更新参数w和b
optimize(0.01)
# 绘制图形,观察训练结束的预测值和真实值
predict = x*w + b #使用训练后的w和b计算预测值
plt.scatter(x.data.numpy(), y.data.numpy(),c = "r")
plt.plot(x.data.numpy(), predict.data.numpy())
plt.show()
print("w",w)
print("b",b)
图形效果如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EPoO4Eha-1656763156468)(…/images/1.2/线性回归1.png)]
打印w和b,可有
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
图形效果如下:
[外链图片转存中...(img-EPoO4Eha-1656763156468)]
打印w和b,可有
```python
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
边栏推荐
- Recursive method constructs binary tree from middle order and post order traversal sequence
- Roulette chart 2 - writing of roulette chart code
- Redis technology leak detection and filling (II) - expired deletion strategy
- 追风赶月莫停留,平芜尽处是春山
- Network learning (I) -- basic model learning
- 2022 National latest fire-fighting facility operator (primary fire-fighting facility operator) simulation questions and answers
- Installing postgresql11 database under centos7
- Leetcode 40: combined sum II
- [UVM foundation] what is transaction
- C language communication travel card background system
猜你喜欢
2022 tea master (intermediate) examination questions and mock examination
game攻防世界逆向
Linux server development, detailed explanation of redis related commands and their principles
json 数据展平pd.json_normalize
Sign up now | oar hacker marathon phase III, waiting for your challenge
You Li takes you to talk about C language 6 (common keywords)
Use and analysis of dot function in numpy
Linux server development, redis protocol and asynchronous mode
【數字IC驗證快速入門】15、SystemVerilog學習之基本語法2(操作符、類型轉換、循環、Task/Function...內含實踐練習)
Ansible
随机推荐
Explore dry goods! Apifox construction ideas
LeetCode简单题之判断一个数的数字计数是否等于数位的值
Recursive method to verify whether a tree is a binary search tree (BST)
Visualization Document Feb 12 16:42
Linux server development, MySQL transaction principle analysis
Summary of redis functions
Custom class loader loads network class
王爽 《汇编语言》之寄存器
Thinkcmf6.0 installation tutorial
自定义类加载器加载网络Class
Info | webrtc M97 update
paddlepaddle 29 无模型定义代码下动态修改网络结构(relu变prelu,conv2d变conv3d,2d语义分割模型改为3d语义分割模型)
青龙面板--花花阅读
OpenJudge NOI 2.1 1752:鸡兔同笼
[advanced digital IC Verification] command query method and common command interpretation of VCs tool
LeetCode简单题之字符串中最大的 3 位相同数字
[UVM basics] summary of important knowledge points of "UVM practice" (continuous update...)
追风赶月莫停留,平芜尽处是春山
2022制冷与空调设备运行操作复训题库及答案
2022 Inner Mongolia latest advanced fire facility operator simulation examination question bank and answers