当前位置:网站首页>基于Pytorch 框架手动完成线性回归
基于Pytorch 框架手动完成线性回归
2022-07-07 04:57:00 【不想秃头的学生】
Pytorch完成线性回归
hello,各位朋友好久不见,最近在忙着期末考试,现在结束之后继续更新咱们的Pytorch框架学习笔记
目标
- 知道
requires_grad的作用 - 知道如何使用
backward - 知道如何手动完成线性回归
1. 向前计算
对于pytorch中的一个tensor,如果设置它的属性 .requires_grad为True,那么它将会追踪对于该张量的所有操作。或者可以理解为,这个tensor是一个参数,后续会被计算梯度,更新该参数。
1.1 计算过程
假设有以下条件(1/4表示求均值,xi中有4个数),使用torch完成其向前计算的过程
KaTeX parse error: No such environment: align* at position 8: \begin{̲a̲l̲i̲g̲n̲*̲}̲ &o = \frac{1}{…
如果x为参数,需要对其进行梯度的计算和更新
那么,在最开始随机设置x的值的过程中,需要设置他的requires_grad属性为True,其默认值为False
import torch
x = torch.ones(2, 2, requires_grad=True) #初始化参数x并设置requires_grad=True用来追踪其计算历史
print(x)
#tensor([[1., 1.],
# [1., 1.]], requires_grad=True)
y = x+2
print(y)
#tensor([[3., 3.],
# [3., 3.]], grad_fn=<AddBackward0>)
z = y*y*3 #平方x3
print(x)
#tensor([[27., 27.],
# [27., 27.]], grad_fn=<MulBackward0>)
out = z.mean() #求均值
print(out)
#tensor(27., grad_fn=<MeanBackward0>)
从上述代码可以看出:
- x的requires_grad属性为True
- 之后的每次计算都会修改其
grad_fn属性,用来记录做过的操作- 通过这个函数和grad_fn能够组成一个和前一小节类似的计算图
1.2 requires_grad和grad_fn
a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad) #False
a.requires_grad_(True) #就地修改
print(a.requires_grad) #True
b = (a * a).sum()
print(b.grad_fn) # <SumBackward0 object at 0x4e2b14345d21>
with torch.no_gard():
c = (a * a).sum() #tensor(151.6830),此时c没有gard_fn
print(c.requires_grad) #False
注意:
为了防止跟踪历史记录(和使用内存),可以将代码块包装在with torch.no_grad():中。在评估模型时特别有用,因为模型可能具有requires_grad = True的可训练的参数,但是我们不需要在此过程中对他们进行梯度计算。
2. 梯度计算
对于1.1 中的out而言,我们可以使用backward方法来进行反向传播,计算梯度
out.backward(),此时便能够求出导数 d o u t d x \frac{d out}{dx} dxdout,调用x.gard能够获取导数值
得到
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
因为:
d ( O ) d ( x i ) = 3 2 ( x i + 2 ) \frac{d(O)}{d(x_i)} = \frac{3}{2}(x_i+2) d(xi)d(O)=23(xi+2)
在 x i x_i xi等于1时其值为4.5
注意:在输出为一个标量的情况下,我们可以调用输出tensor的backword() 方法,但是在数据是一个向量的时候,调用backward()的时候还需要传入其他参数。
很多时候我们的损失函数都是一个标量,所以这里就不再介绍损失为向量的情况。
loss.backward()就是根据损失函数,对参数(requires_grad=True)的去计算他的梯度,并且把它累加保存到x.gard,此时还并未更新其梯度
注意点:
tensor.data:在tensor的require_grad=False,tensor.data和tensor等价
require_grad=True时,tensor.data仅仅是获取tensor中的数据
tensor.numpy():require_grad=True不能够直接转换,需要使用tensor.detach().numpy()
3. 线性回归实现
下面,我们使用一个自定义的数据,来使用torch实现一个简单的线性回归
假设我们的基础模型就是y = wx+b,其中w和b均为参数,我们使用y = 3x+0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8
- 准备数据
- 计算预测值
- 计算损失,把参数的梯度置为0,进行反向传播
- 更新参数
import torch
import numpy as np
from matplotlib import pyplot as plt
#1. 准备数据 y = 3x+0.8,准备参数
x = torch.rand([50])
y = 3*x + 0.8
w = torch.rand(1,requires_grad=True)
b = torch.rand(1,requires_grad=True)
def loss_fn(y,y_predict):
loss = (y_predict-y).pow(2).mean()
for i in [w,b]:
#每次反向传播前把梯度置为0
if i.grad is not None:
i.grad.data.zero_()
# [i.grad.data.zero_() for i in [w,b] if i.grad is not None]
loss.backward()
return loss.data
def optimize(learning_rate):
# print(w.grad.data,w.data,b.data)
w.data -= learning_rate* w.grad.data
b.data -= learning_rate* b.grad.data
for i in range(3000):
#2. 计算预测值
y_predict = x*w + b
#3.计算损失,把参数的梯度置为0,进行反向传播
loss = loss_fn(y,y_predict)
if i%500 == 0:
print(i,loss)
#4. 更新参数w和b
optimize(0.01)
# 绘制图形,观察训练结束的预测值和真实值
predict = x*w + b #使用训练后的w和b计算预测值
plt.scatter(x.data.numpy(), y.data.numpy(),c = "r")
plt.plot(x.data.numpy(), predict.data.numpy())
plt.show()
print("w",w)
print("b",b)
图形效果如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EPoO4Eha-1656763156468)(…/images/1.2/线性回归1.png)]
打印w和b,可有
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
图形效果如下:
[外链图片转存中...(img-EPoO4Eha-1656763156468)]
打印w和b,可有
```python
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
边栏推荐
- LeetCode 40:组合总和 II
- 王爽 《汇编语言》之寄存器
- 2022 Inner Mongolia latest advanced fire facility operator simulation examination question bank and answers
- Info | webrtc M97 update
- CTF daily question day43 rsa5
- [quick start of Digital IC Verification] 17. Basic grammar of SystemVerilog learning 4 (randomization)
- B. Value sequence thinking
- JS quick start (I)
- 【数字IC验证快速入门】17、SystemVerilog学习之基本语法4(随机化Randomization)
- Minimum absolute difference of binary search tree (use medium order traversal as an ordered array)
猜你喜欢

2022 recurrent training question bank and answers of refrigeration and air conditioning equipment operation

buureservewp(2)

json 数据展平pd.json_normalize

Common validation comments
![[CV] Wu Enda machine learning course notes | Chapter 8](/img/c0/7a39355fb3a6cb506f0fbcf2a7aa24.jpg)
[CV] Wu Enda machine learning course notes | Chapter 8

2022茶艺师(初级)考试题模拟考试题库及在线模拟考试

Li Kou interview question 04.01 Path between nodes
![[quick start of Digital IC Verification] 17. Basic grammar of SystemVerilog learning 4 (randomization)](/img/39/cac2b5492d374da393569e2ab467a4.png)
[quick start of Digital IC Verification] 17. Basic grammar of SystemVerilog learning 4 (randomization)

LeetCode 90:子集 II

Linux server development, MySQL transaction principle analysis
随机推荐
【数字IC验证快速入门】14、SystemVerilog学习之基本语法1(数组、队列、结构体、枚举、字符串...内含实践练习)
2022茶艺师(初级)考试题模拟考试题库及在线模拟考试
Leetcode 43 String multiplication (2022.02.12)
【数字IC验证快速入门】11、Verilog TestBench(VTB)入门
[quick start of Digital IC Verification] 17. Basic grammar of SystemVerilog learning 4 (randomization)
Qt学习26 布局管理综合实例
Introduction to basic components of wechat applet
[UVM practice] Chapter 1: configuring the UVM environment (taking VCs as an example), run through the examples in the book
Who has docker to install MySQL locally?
Force buckle 145 Binary Tree Postorder Traversal
MySQL multi column index (composite index) features and usage scenarios
Leanote private cloud note building
These five fishing artifacts are too hot! Programmer: I know, delete it quickly!
Value sequence (subsequence contribution problem)
Sign up now | oar hacker marathon phase III, waiting for your challenge
【数字IC验证快速入门】12、SystemVerilog TestBench(SVTB)入门
UnityHub破解&Unity破解
[2022 ciscn] replay of preliminary web topics
有 Docker 谁还在自己本地安装 Mysql ?
Cnopendata American Golden Globe Award winning data