当前位置:网站首页>玩转Pytorch的Function类
玩转Pytorch的Function类
2022-06-10 16:53:00 【武乐乐~】
前言
pytorch提供了autograd自动求导机制,而autograd实现自动求导实质上通过Function类实现的。而习惯搭积木的伙伴平时也不写backward。造成需要拓展算子情况便会手足无措。本文从简单例子入手,学习实现一个Function类最基本的要素,同时还会涉及一些注意事项,最后在结合一个实战来学习Function类的使用。
1、y=w*x+b
import torch
from torch.autograd import Function
# y = w*x + b 的一个前向传播和反向求导
class Mul(Function):
@staticmethod
def forward(ctx, w, x, b, x_requires_grad = True): # ctx可以理解为元祖,用来存储梯度的中间缓存变量。
ctx.save_for_backward(w,b) # 因为dy/dx = w; dy/dw = x ; dy/db = 1;为了后续反向传播需要保存中间变量w,x
output = w*x + b
return output
@staticmethod
def backward(ctx,grad_outputs): # 此处grad_outputs 具体问题具体分析
w = ctx.saved_tensors[0] # 取出ctx中保存的 w = 2
b = ctx.saved_tensors[1] # 取出ctx中保存的 b = 3
grad_w = grad_outputs * x # 1 * 1 = 1
grad_x = grad_outputs * w # 1 * 2 = 2
grad_b = grad_outputs * 1 # 1 * 1 = 1
return grad_w, grad_x, grad_b, None # 返回的参数和forward的参数一一对应,对于参数x_requires_grad不必求梯度则直接返回None。
if __name__ == '__main__':
x = torch.tensor(1.,requires_grad=True)
w = torch.tensor(2.,requires_grad=True)
b = torch.tensor(3., requires_grad=True)
y = Mul.apply(w,x,b) # y = w*x + b = 2*1 + 3 = 5
print('forward:', y)
# 写法一
loss = y.sum() # 转成标量
loss.backward() # 反向传播:因为 loss = sum(y),故grad_outputs = dloss/dy = 1,可以省略不写
print('写法一的梯度:',x.grad, w.grad, b.grad) # tensor(2.) tensor(1.) tensor(1.)
这里简单说下:代码中注释有问题欢迎留言评论。其中y=w*x+b。前向传播容易理解。这里令人困惑的应该是ctx这个东西,其实可以将其理解为一个元祖,通过方法save_for_backward()来保存前向传播的中间缓存变量,为后续反向传播提供条件。而在反向传播中,首先从ctx中通过调用方法saved_tensors[]来得到w,b。之后各个参数的梯度:dy/dx = w; dy/dw = x; dy/db = 1。
另外,在反向传播中,令人困惑就是参数grad_outputs。其实这个参数的值跟类调用完之后有关。在代码中,使用loss.backward(),可以看见传入的参数是个空。这是因为在计算完前向传播得到y之后,loss = y.sum(),即grad_outputs = dloss/dy = 1; 而在torch中,可以省略不写。故此处的grad_outputs=1.
当然,我们也可以明示的传参进去。
# 写法一
loss1 = y.sum() # 转成标量
loss1.backward() # 反向传播:因为 loss = sum(y),故grad_outputs = dloss/dy = 1,可以省略不写
print('写法一的梯度:',x.grad, w.grad, b.grad) # tensor(2.) tensor(1.) tensor(1.)
# 写法二
loss2 = y.sum()
loss2.backward(torch.tensor(1.))
print('写法二的梯度:',x.grad, w.grad, b.grad) # tensor(4.) tensor(2.) tensor(2.)
但是此时报错了,报错信息如下:
RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.
大体意思说同一个计算图不能反向传播两次。因为,在调用第一次backward之后,计算图就销毁了。所以需要通过设置参数retain_graph参数保存计算图,更改后代码如下:
# 写法一
loss1 = y.sum() # 转成标量
loss1.backward(retain_graph = True) # 反向传播:因为 loss = sum(y),故grad_outputs = dloss/dy = 1,可以省略不写
print('写法一的梯度:',x.grad, w.grad, b.grad) # tensor(2.) tensor(1.) tensor(1.)
# 写法二
loss2 = y.sum()
loss2.backward(torch.tensor(1.))
print('写法二的梯度:',x.grad, w.grad, b.grad) # tensor(4.) tensor(2.) tensor(2.)
不幸的是,此时写法二和写法一的梯度计算结果不一致,发现写法二的梯度是写法一梯度的两倍。是因为在pytorch中两次不同loss在反传梯度时在叶子节点梯度是累加的。因此,我们在损失二传播之间需要将损失一的梯度清0。代码如下:
# 写法一
loss1 = y.sum() # 转成标量
loss1.backward(retain_graph = True) # 反向传播:因为 loss = sum(y),故grad_outputs = dloss/dy = 1,可以省略不写
print('写法一的梯度:',x.grad, w.grad, b.grad) # tensor(2.) tensor(1.) tensor(1.)
# 叶子节点梯度清0
x.grad.zero_()
w.grad.zero_()
b.grad.zero_()
# 写法二
loss2 = y.sum()
loss2.backward(torch.tensor(1.))
print('写法二的梯度:',x.grad, w.grad, b.grad) # tensor(2.) tensor(1.) tensor(1.)
OK,大功告成。完整代码如下:
import torch
from torch.autograd import Function
# y = w*x + b 的一个前向传播和反向求导
class Mul(Function):
@staticmethod
def forward(ctx, w, x, b, x_requires_grad = True): # ctx可以理解为元祖,用来存储梯度的中间缓存变量。
ctx.save_for_backward(w,b) # 因为dy/dx = w; dy/dw = x ; dy/db = 1;为了后续反向传播需要保存中间变量w,x
output = w*x + b
return output
@staticmethod
def backward(ctx,grad_outputs): # 此处grad_outputs 具体问题具体分析
w = ctx.saved_tensors[0] # 取出ctx中保存的 w = 2
b = ctx.saved_tensors[1] # 取出ctx中保存的 b = 3
grad_w = grad_outputs * x # 1 * 1 = 1
grad_x = grad_outputs * w # 1 * 2 = 2
grad_b = grad_outputs * 1 # 1 * 1 = 1
return grad_w, grad_x, grad_b, None # 返回的参数和forward的参数一一对应,对于参数x_requires_grad不必求梯度则直接返回None。
if __name__ == '__main__':
x = torch.tensor(1.,requires_grad=True)
w = torch.tensor(2.,requires_grad=True)
b = torch.tensor(3., requires_grad=True)
y = Mul.apply(w,x,b) # y = w*x + b = 2*1 + 3 = 5
print('forward:', y)
# 写法一
loss1 = y.sum() # 转成标量
loss1.backward(retain_graph = True) # 反向传播:因为 loss = sum(y),故grad_outputs = dloss/dy = 1,可以省略不写
print('写法一的梯度:',x.grad, w.grad, b.grad) # tensor(2.) tensor(1.) tensor(1.)
# 叶子节点梯度清0
x.grad.zero_()
w.grad.zero_()
b.grad.zero_()
# 写法二
loss2 = y.sum()
loss2.backward(torch.tensor(1.))
print('写法二的梯度:',x.grad, w.grad, b.grad) # tensor(4.) tensor(2.) tensor(2.)
2、进阶:y=exp(x)*2
import torch
from torch.autograd import Function
class Exp(Function):
@staticmethod
def forward(ctx,x):
output = x.exp()
ctx.save_for_backward(output) # dy/dx = exp(x)
return output
@staticmethod
def backward(ctx, grad_outputs): # dloss/dx = grad_outputs* exp(x)
output = ctx.saved_tensors[0]
return output*grad_outputs
if __name__ == '__main__':
x = torch.tensor(1.,requires_grad=True)
y = Exp.apply(x)
print(y)
y = y * 2
loss = y.sum()
loss.backward()
print(x.grad)
唯一需要注意就是:dloss/dy = 1 * 2 = 2;因为loss = sum(2y)。
3、实战:GuideReLU函数
ReLU函数:y=max(x,0),反传梯度时仅x>0的位置才有梯度,且梯度值为1.因为y=x,所以dy/dx=1;而GuideReLU是在ReLU基础上,不仅x>0位置才能反传梯度,还要满足梯度>0位置才能反传梯度。dloss/dx = dloss/dy * (x>0) * (grad_output>0)。代码如下:
import torch
from torch.autograd import Function
class GuideReLU(Function):
@staticmethod
def forward(ctx,input):
output = torch.clamp(input,min=0)
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_outputs): # dloss/dx = dloss/dy * !(x > 0) * (dloss/dy > 0)
output = ctx.saved_tensors[0] # dloss/dy
return grad_outputs * (output>0).float()* (grad_outputs>0).float()
if __name__ == '__main__':
x = torch.randn(2,3,requires_grad=True)
print('input:',x)
y = GuideReLU.apply(x)
print('forward:',y)
grad_y = torch.randn(2,3)
y.backward(grad_y) # 此处接收一个梯度数值,即grad_outputs
print('grad_y:',grad_y) # 即只有当输入x和返回梯度grad_y同时>0位置才有梯度值。
print('grad_x:',x.grad)
4、梯度检查:torch.autograd.gradcheck()
pytorch提供了一个梯度检查api,可以很方便检测自己写的传播是否正确。
import torch
from torch.autograd import Function
class Sigmoid(Function):
@staticmethod
def forward(ctx, x):
output = 1 / (1 + torch.exp(-x))
ctx.save_for_backward(output)
return output
@staticmethod
def backward(ctx, grad_output):
output, = ctx.saved_tensors
grad_x = output * (1 - output) * grad_output
return grad_x
test_input = torch.randn(4, requires_grad=True) # tensor([-0.4646, -0.4403, 1.2525, -0.5953], requires_grad=True)
print(torch.autograd.gradcheck(Sigmoid.apply, (test_input,), eps=1e-3))
总结
本篇是介绍pytorch反向传导的第一篇,后续会介绍拓展C++/CUDA算子。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。
边栏推荐
- AIChE | ab initio drug design framework integrating mathematical programming method and deep learning model
- Snabbdom 虚拟 dom(一)
- com. netflix. client. ClientException: Load balancer does not have available server for client: userser
- THE LOTTERY TICKET HYPOTHESIS: FINDING SPARSE, TRAINABLE NEURAL NETWORKS论文笔记
- B站不想成为“良心版爱优腾”
- 海外数据中心需要为不可预测的灾难做好准备
- [play with Huawei cloud] Kunpeng devkit migration practice
- Knowledge-based BERT: 像计算化学家一样提取分子特征的方法
- Importerror: libgl.so.1: cannot open shared object file: no such file or directory
- Snabbdom virtual DOM (I)
猜你喜欢

2022G1工业锅炉司炉考题及在线模拟考试

Only three steps are needed to learn how to use low code thingjs to connect with Sen data Dix data

Nacos registry

运筹说 第64期丨动态规划奠基人——理查德·贝尔曼

2022年T电梯修理考试题模拟考试题库及在线模拟考试

Eliminate if Five ways of else

Nacos注册中心

Chongqing's first sci tech Innovation Board IPO is coming

丢失的遗传力--Missing heritability

2022 version of idea graphical interface GUI garbled code solution super detailed simple version
随机推荐
How MySQL modifies field type and field length
Thread interview related questions
Photoshop如何打开、编辑和导出Webp格式图片的方法
When V-IF and V-for need to be used at the same time
Leetcode String to integer(Atoi)
Gateway service gateway
2022年茶艺师(中级)操作证考试题库及模拟考试
Web3 is the most complete money making secret script. Just read this one
Draw confusion matrix
重庆第一个科创板IPO,来了
Designing drugs with code: are we here yet?
如何运行plink软件--三种方法
Snabbdom 虚拟 dom(一)
IPO治不了威马的杂症?
Feign based remote call
Nat. Rev. Drug Discov. | Application of AI in small molecule drug discovery: an upcoming wave?
C#_串口通信项目
2022年茶艺师(中级)操作证考试题库及模拟考试
mapbox-gl开发教程(十一):加载线图层
Why does the universe limit its maximum speed to the speed of light