当前位置:网站首页>基于Pytorch 框架手动完成线性回归
基于Pytorch 框架手动完成线性回归
2022-07-07 04:57:00 【不想秃头的学生】
Pytorch完成线性回归
hello,各位朋友好久不见,最近在忙着期末考试,现在结束之后继续更新咱们的Pytorch框架学习笔记
目标
- 知道
requires_grad的作用 - 知道如何使用
backward - 知道如何手动完成线性回归
1. 向前计算
对于pytorch中的一个tensor,如果设置它的属性 .requires_grad为True,那么它将会追踪对于该张量的所有操作。或者可以理解为,这个tensor是一个参数,后续会被计算梯度,更新该参数。
1.1 计算过程
假设有以下条件(1/4表示求均值,xi中有4个数),使用torch完成其向前计算的过程
KaTeX parse error: No such environment: align* at position 8: \begin{̲a̲l̲i̲g̲n̲*̲}̲ &o = \frac{1}{…
如果x为参数,需要对其进行梯度的计算和更新
那么,在最开始随机设置x的值的过程中,需要设置他的requires_grad属性为True,其默认值为False
import torch
x = torch.ones(2, 2, requires_grad=True) #初始化参数x并设置requires_grad=True用来追踪其计算历史
print(x)
#tensor([[1., 1.],
# [1., 1.]], requires_grad=True)
y = x+2
print(y)
#tensor([[3., 3.],
# [3., 3.]], grad_fn=<AddBackward0>)
z = y*y*3 #平方x3
print(x)
#tensor([[27., 27.],
# [27., 27.]], grad_fn=<MulBackward0>)
out = z.mean() #求均值
print(out)
#tensor(27., grad_fn=<MeanBackward0>)
从上述代码可以看出:
- x的requires_grad属性为True
- 之后的每次计算都会修改其
grad_fn属性,用来记录做过的操作- 通过这个函数和grad_fn能够组成一个和前一小节类似的计算图
1.2 requires_grad和grad_fn
a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad) #False
a.requires_grad_(True) #就地修改
print(a.requires_grad) #True
b = (a * a).sum()
print(b.grad_fn) # <SumBackward0 object at 0x4e2b14345d21>
with torch.no_gard():
c = (a * a).sum() #tensor(151.6830),此时c没有gard_fn
print(c.requires_grad) #False
注意:
为了防止跟踪历史记录(和使用内存),可以将代码块包装在with torch.no_grad():中。在评估模型时特别有用,因为模型可能具有requires_grad = True的可训练的参数,但是我们不需要在此过程中对他们进行梯度计算。
2. 梯度计算
对于1.1 中的out而言,我们可以使用backward方法来进行反向传播,计算梯度
out.backward(),此时便能够求出导数 d o u t d x \frac{d out}{dx} dxdout,调用x.gard能够获取导数值
得到
tensor([[4.5000, 4.5000],
[4.5000, 4.5000]])
因为:
d ( O ) d ( x i ) = 3 2 ( x i + 2 ) \frac{d(O)}{d(x_i)} = \frac{3}{2}(x_i+2) d(xi)d(O)=23(xi+2)
在 x i x_i xi等于1时其值为4.5
注意:在输出为一个标量的情况下,我们可以调用输出tensor的backword() 方法,但是在数据是一个向量的时候,调用backward()的时候还需要传入其他参数。
很多时候我们的损失函数都是一个标量,所以这里就不再介绍损失为向量的情况。
loss.backward()就是根据损失函数,对参数(requires_grad=True)的去计算他的梯度,并且把它累加保存到x.gard,此时还并未更新其梯度
注意点:
tensor.data:在tensor的require_grad=False,tensor.data和tensor等价
require_grad=True时,tensor.data仅仅是获取tensor中的数据
tensor.numpy():require_grad=True不能够直接转换,需要使用tensor.detach().numpy()
3. 线性回归实现
下面,我们使用一个自定义的数据,来使用torch实现一个简单的线性回归
假设我们的基础模型就是y = wx+b,其中w和b均为参数,我们使用y = 3x+0.8来构造数据x、y,所以最后通过模型应该能够得出w和b应该分别接近3和0.8
- 准备数据
- 计算预测值
- 计算损失,把参数的梯度置为0,进行反向传播
- 更新参数
import torch
import numpy as np
from matplotlib import pyplot as plt
#1. 准备数据 y = 3x+0.8,准备参数
x = torch.rand([50])
y = 3*x + 0.8
w = torch.rand(1,requires_grad=True)
b = torch.rand(1,requires_grad=True)
def loss_fn(y,y_predict):
loss = (y_predict-y).pow(2).mean()
for i in [w,b]:
#每次反向传播前把梯度置为0
if i.grad is not None:
i.grad.data.zero_()
# [i.grad.data.zero_() for i in [w,b] if i.grad is not None]
loss.backward()
return loss.data
def optimize(learning_rate):
# print(w.grad.data,w.data,b.data)
w.data -= learning_rate* w.grad.data
b.data -= learning_rate* b.grad.data
for i in range(3000):
#2. 计算预测值
y_predict = x*w + b
#3.计算损失,把参数的梯度置为0,进行反向传播
loss = loss_fn(y,y_predict)
if i%500 == 0:
print(i,loss)
#4. 更新参数w和b
optimize(0.01)
# 绘制图形,观察训练结束的预测值和真实值
predict = x*w + b #使用训练后的w和b计算预测值
plt.scatter(x.data.numpy(), y.data.numpy(),c = "r")
plt.plot(x.data.numpy(), predict.data.numpy())
plt.show()
print("w",w)
print("b",b)
图形效果如下:
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-EPoO4Eha-1656763156468)(…/images/1.2/线性回归1.png)]
打印w和b,可有
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
图形效果如下:
[外链图片转存中...(img-EPoO4Eha-1656763156468)]
打印w和b,可有
```python
w tensor([2.9280], requires_grad=True)
b tensor([0.8372], requires_grad=True)
可知,w和b已经非常接近原来的预设的3和0.8
边栏推荐
猜你喜欢

Sign up now | oar hacker marathon phase III, waiting for your challenge

微信小程序基本组件使用介绍

Ansible

Qt学习27 应用程序中的主窗口

Thinkcmf6.0 installation tutorial

Cnopendata list data of Chinese colleges and Universities

2022 recurrent training question bank and answers of refrigeration and air conditioning equipment operation

MySQL multi column index (composite index) features and usage scenarios
![[2022 ciscn] replay of preliminary web topics](/img/1c/4297379fccde28f76ebe04d085c5a4.png)
[2022 ciscn] replay of preliminary web topics

自定义类加载器加载网络Class
随机推荐
Shell 脚本的替换功能实现
These five fishing artifacts are too hot! Programmer: I know, delete it quickly!
LeetCode 40:组合总和 II
[advanced digital IC Verification] command query method and common command interpretation of VCs tool
[VHDL parallel statement execution]
C language communication travel card background system
uniapp 移动端强制更新功能
海信电视开启开发者模式
Linux server development, MySQL index principle and optimization
Pytorch(六) —— 模型调优tricks
Linux server development, redis protocol and asynchronous mode
有 Docker 谁还在自己本地安装 Mysql ?
You Li takes you to talk about C language 6 (common keywords)
2022 National latest fire-fighting facility operator (primary fire-fighting facility operator) simulation questions and answers
LeetCode简单题之字符串中最大的 3 位相同数字
太真实了,原来自己一直没有富裕起来是有原因的
[quick start of Digital IC Verification] 15. Basic syntax of SystemVerilog learning 2 (operators, type conversion, loops, task/function... Including practical exercises)
paddlepaddle 29 无模型定义代码下动态修改网络结构(relu变prelu,conv2d变conv3d,2d语义分割模型改为3d语义分割模型)
Custom class loader loads network class
Zsh shell adds automatic completion and syntax highlighting