当前位置:网站首页>【pytorch】pytorch 自动求导、 Tensor 与 Autograd
【pytorch】pytorch 自动求导、 Tensor 与 Autograd
2022-07-31 16:15:00 【Enzo 想砸电脑】
在神经网络中,一个重要内容就是进行参数学习,而参数学习离不开求导。
现在大部分深度学习架构都有自动求导的功能,torch.autograd包 就是用来自动求导的。
torch.autograd 包为张量上 所有的操作 提供了自动求导功能
这一篇学习并记录一下 自动求导 的要点。
一、计算图
在整个向前计算过程中,PyTorch采用计算图的形式进行组织,该计算图为动态图,且在每次 前向传播时,将重新构建。其他深度学习架构,如TensorFlow、Keras一般为静态图。

- 计算图是一种有向无环图像,用图形方式来表示算子与变量之间的关系,直观高效。
- 图中 圆形表示变量,矩阵表示算子
- 表达式:z=wx+b,可写成两个表示式: y=wx,z=y+b,
- 其中x、w、b为变量,是用户创建的变量,不依赖于其他变量,故又称 为叶子节点。为计算各叶子节点的梯度,需要把对应的张量参数requires_grad属性设置为 True,这样就可自动跟踪其历史记录。(后面会细说)
- y、z 是计算得到的变量,非叶子节点,z为根节点
- mul和add是算子(或操作或函数)
这些变量及算子就构成了一个完整的计算过程 (或前向传播过程)
二、自动求导要点
为实现对Tensor自动求导,需考虑如下事项:
1)创建叶子节点(Leaf Node)的Tensor,使用requires_grad参数指定是否记录对其 的操作,以便之后利用backward()方法进行梯度求解。requires_grad参数的缺省值为 False,如果要对其求导需设置为True,然后与之有依赖关系的节点会自动变为True。
2)可利用requires_grad_()方法修改Tensor的requires_grad属性(比如一开始在训练阶段,requires_grad 值设置为了True,在测试阶段修改为 False)。可以调用.detach()或 with torch.no_grad():,将不再计算张量的梯度,跟踪张量的历史记录。这点在评估模 型、测试模型阶段中常常用到。
3)通过运算创建的Tensor(即非叶子节点),会自动被赋予grad_fn属性。该属性表 示梯度函数。叶子节点的grad_fn为None。
4)最后得到的Tensor(根节点)执行backward()函数,此时自动计算各变量的梯度。
- 每次反向传播结束,叶子结点的梯度会被清空。如果需要多次反向传播的梯度累加,需要指定backward 中的参数retain_graph=True,这样子节点的梯度是累加的。
- 非叶子节点的梯度backward调用后即被清空
5)backward()函数接收参数,该参数应和调用backward()函数的Tensor的维度相同, 或者是可broadcast的维度。如果求导的Tensor为标量(即一个数字),则backward中的参数可省略。
三、标量反向传播的计算

- 假设x、w、b都是标量,则计算结果 z 也是标量 (z=wx+b)
- 对根节点z调用backward()方法,我们无须对 backward()传入参数
* 这里先提一嘴,后面会说到的是: 如果目标张量对一个非标量调用backward(),则需要传入一个 gradient参数,该参数也是张量,而且需要与调用backward()的张量形状相同。
以下是实现自动求导的主要步骤:
import torch
# 输入张量 x
x = torch.Tensor([2])
# 初始化 权重参数w, 偏移量b,并设置 require_grad 属性为 True, 为自动求导
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
# 实现向前传播
y = torch.mul(w, x)
z = torch.add(y, b)
# 分别查看叶子节点 x, w, b 和 非叶子节点 y、z 的require_grad属性
print(x.requires_grad, w.requires_grad, b.requires_grad) # False True True
print(y.requires_grad, z.requires_grad ) # True True
# 查看各节点是否为叶子节点
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf) # True True True False False
# 分别查看 叶子节点 和 非叶子节点 的 grad_fn 属性
print(x.grad_fn, w.grad_fn, b.grad_fn) # None None None
print(y.grad_fn, z.grad_fn) # <MulBackward0 object at 0x7f8ac1303910> <AddBackward0 object at 0x7f8ac1303070>
z.backward() # 梯度不会累加
# z.backward(retain_graph=True) # 如果多次使用backward,需要梯度累加,则需要修改参数retain_graph为True
# 查看叶子节点的梯度,x是叶子节点但它无须求导,故其梯度为None
print(w.grad,b.grad,x.grad) # tensor([2.]) tensor([1.]) None
#非叶子节点的梯度,执行backward之后,会自动清空
print(y.grad,z.grad) # None None
四、非标量反向传播的计算
边栏推荐
- i.MX6ULL driver development | 33 - NXP original network device driver reading (LAN8720 PHY)
- Delete table data or clear table
- The principle of hough transform detection of straight lines (opencv hough straight line detection)
- Applicable scenario of multi-master replication (2) - client and collaborative editing that require offline operation
- MySQL多表联合查询
- form 表单提交后,使页面不跳转[通俗易懂]
- Visualize GraphQL schemas with GraphiQL
- org.apache.jasperException(could not initialize class org)
- 6-22漏洞利用-postgresql数据库密码破解
- 对话庄表伟:开源第一课
猜你喜欢
随机推荐
Character pointer assignment [easy to understand]
入职一个月反思
mongo enters error
Replication Latency Case (1) - Eventual Consistency
牛客 HJ16 购物单
牛客网刷题(二)
Tencent Cloud Deployment----DevOps
Deployment application life cycle and Pod health check
Bilateral filtering acceleration "recommended collection"
2.索引及调优篇【mysql高级】
动态规划(一)
Kubernetes common commands
JVM parameter analysis Xmx, Xms, Xmn, NewRatio, SurvivorRatio, PermSize, PrintGC "recommended collection"
How to switch remote server in gerrit
.NET 20th Anniversary Interview - Zhang Shanyou: How .NET technology empowers and changes the world
在资源管理类中提供对原始资源的访问——条款15
C语言-函数
ansible学习笔记02
Use of radiobutton
Design and Implementation of Compiler Based on C Language








