当前位置:网站首页>【pytorch】pytorch 自动求导、 Tensor 与 Autograd
【pytorch】pytorch 自动求导、 Tensor 与 Autograd
2022-07-31 16:15:00 【Enzo 想砸电脑】
在神经网络中,一个重要内容就是进行参数学习,而参数学习离不开求导。
现在大部分深度学习架构都有自动求导的功能,torch.autograd包 就是用来自动求导的。
torch.autograd 包为张量上 所有的操作 提供了自动求导功能
这一篇学习并记录一下 自动求导 的要点。
一、计算图
在整个向前计算过程中,PyTorch采用计算图的形式进行组织,该计算图为动态图,且在每次 前向传播时,将重新构建。其他深度学习架构,如TensorFlow、Keras一般为静态图。
- 计算图是一种有向无环图像,用图形方式来表示算子与变量之间的关系,直观高效。
- 图中 圆形表示变量,矩阵表示算子
- 表达式:z=wx+b,可写成两个表示式: y=wx,z=y+b,
- 其中x、w、b为变量,是用户创建的变量,不依赖于其他变量,故又称 为叶子节点。为计算各叶子节点的梯度,需要把对应的张量参数requires_grad属性设置为 True,这样就可自动跟踪其历史记录。(后面会细说)
- y、z 是计算得到的变量,非叶子节点,z为根节点
- mul和add是算子(或操作或函数)
这些变量及算子就构成了一个完整的计算过程 (或前向传播过程)
二、自动求导要点
为实现对Tensor自动求导,需考虑如下事项:
1)创建叶子节点(Leaf Node)的Tensor,使用requires_grad参数指定是否记录对其 的操作,以便之后利用backward()方法进行梯度求解。requires_grad参数的缺省值为 False,如果要对其求导需设置为True,然后与之有依赖关系的节点会自动变为True。
2)可利用requires_grad_()方法修改Tensor的requires_grad属性(比如一开始在训练阶段,requires_grad 值设置为了True,在测试阶段修改为 False)。可以调用.detach()或 with torch.no_grad():,将不再计算张量的梯度,跟踪张量的历史记录。这点在评估模 型、测试模型阶段中常常用到。
3)通过运算创建的Tensor(即非叶子节点),会自动被赋予grad_fn属性。该属性表 示梯度函数。叶子节点的grad_fn为None。
4)最后得到的Tensor(根节点)执行backward()函数,此时自动计算各变量的梯度。
- 每次反向传播结束,叶子结点的梯度会被清空。如果需要多次反向传播的梯度累加,需要指定backward 中的参数retain_graph=True,这样子节点的梯度是累加的。
- 非叶子节点的梯度backward调用后即被清空
5)backward()函数接收参数,该参数应和调用backward()函数的Tensor的维度相同, 或者是可broadcast的维度。如果求导的Tensor为标量(即一个数字),则backward中的参数可省略。
三、标量反向传播的计算
- 假设x、w、b都是标量,则计算结果 z 也是标量 (z=wx+b)
- 对根节点z调用backward()方法,我们无须对 backward()传入参数
* 这里先提一嘴,后面会说到的是: 如果目标张量对一个非标量调用backward(),则需要传入一个 gradient参数,该参数也是张量,而且需要与调用backward()的张量形状相同。
以下是实现自动求导的主要步骤:
import torch
# 输入张量 x
x = torch.Tensor([2])
# 初始化 权重参数w, 偏移量b,并设置 require_grad 属性为 True, 为自动求导
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
# 实现向前传播
y = torch.mul(w, x)
z = torch.add(y, b)
# 分别查看叶子节点 x, w, b 和 非叶子节点 y、z 的require_grad属性
print(x.requires_grad, w.requires_grad, b.requires_grad) # False True True
print(y.requires_grad, z.requires_grad ) # True True
# 查看各节点是否为叶子节点
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf) # True True True False False
# 分别查看 叶子节点 和 非叶子节点 的 grad_fn 属性
print(x.grad_fn, w.grad_fn, b.grad_fn) # None None None
print(y.grad_fn, z.grad_fn) # <MulBackward0 object at 0x7f8ac1303910> <AddBackward0 object at 0x7f8ac1303070>
z.backward() # 梯度不会累加
# z.backward(retain_graph=True) # 如果多次使用backward,需要梯度累加,则需要修改参数retain_graph为True
# 查看叶子节点的梯度,x是叶子节点但它无须求导,故其梯度为None
print(w.grad,b.grad,x.grad) # tensor([2.]) tensor([1.]) None
#非叶子节点的梯度,执行backward之后,会自动清空
print(y.grad,z.grad) # None None
四、非标量反向传播的计算
边栏推荐
猜你喜欢
t-sne 数据可视化网络中的部分参数+
"Autumn Recruitment Series" MySQL Interview Core 25 Questions (with answers)
2022年整理LeetCode最新刷题攻略分享(附中文详细题解)
.NET 20周年专访 - 张善友:.NET 技术是如何赋能并改变世界的
基于Redis(SETNX)实现分布式锁,案例:解决高并发下的订单超卖,秒杀
How C programs run 01 - the composition of ordinary executable files
利用PHP开发具有注册、登陆、文件上传、发布动态功能的网站
使用 Postman 工具高效管理和测试 SAP ABAP OData 服务的试读版
MySQL基础篇【单行函数】
【TypeScript】深入学习TypeScript类型操作
随机推荐
mysql黑窗口~建库建表
Tencent Cloud Deployment----DevOps
2022年整理LeetCode最新刷题攻略分享(附中文详细题解)
Kubernetes principle analysis and practical application manual, too complete
Kubernetes常用命令
How Redis handles concurrent access
Character pointer assignment [easy to understand]
The new BMW 3 Series is on the market, with safety and comfort
【7.29】代码源 - 【排列】【石子游戏 II】【Cow and Snacks】【最小生成数】【数列】
JVM parameter analysis Xmx, Xms, Xmn, NewRatio, SurvivorRatio, PermSize, PrintGC "recommended collection"
外媒所言非虚,苹果降价或许是真的在清库存
i.MX6ULL驱动开发 | 33 - NXP原厂网络设备驱动浅读(LAN8720 PHY)
贪吃蛇项目(简单)
基于ABP实现DDD
入职一个月反思
type of timer
基于C语言的编译器设计与实现
Dialogue with Zhuang Biaowei: The first lesson of open source
Premiere Pro 2022 for (pr 2022)v22.5.0
Qt实战案例(54)——利用QPixmap设计图片透明度