当前位置:网站首页>【pytorch】pytorch 自动求导、 Tensor 与 Autograd
【pytorch】pytorch 自动求导、 Tensor 与 Autograd
2022-07-31 16:15:00 【Enzo 想砸电脑】
在神经网络中,一个重要内容就是进行参数学习,而参数学习离不开求导。
现在大部分深度学习架构都有自动求导的功能,torch.autograd包 就是用来自动求导的。
torch.autograd 包为张量上 所有的操作 提供了自动求导功能
这一篇学习并记录一下 自动求导 的要点。
一、计算图
在整个向前计算过程中,PyTorch采用计算图的形式进行组织,该计算图为动态图,且在每次 前向传播时,将重新构建。其他深度学习架构,如TensorFlow、Keras一般为静态图。

- 计算图是一种有向无环图像,用图形方式来表示算子与变量之间的关系,直观高效。
- 图中 圆形表示变量,矩阵表示算子
- 表达式:z=wx+b,可写成两个表示式: y=wx,z=y+b,
- 其中x、w、b为变量,是用户创建的变量,不依赖于其他变量,故又称 为叶子节点。为计算各叶子节点的梯度,需要把对应的张量参数requires_grad属性设置为 True,这样就可自动跟踪其历史记录。(后面会细说)
- y、z 是计算得到的变量,非叶子节点,z为根节点
- mul和add是算子(或操作或函数)
这些变量及算子就构成了一个完整的计算过程 (或前向传播过程)
二、自动求导要点
为实现对Tensor自动求导,需考虑如下事项:
1)创建叶子节点(Leaf Node)的Tensor,使用requires_grad参数指定是否记录对其 的操作,以便之后利用backward()方法进行梯度求解。requires_grad参数的缺省值为 False,如果要对其求导需设置为True,然后与之有依赖关系的节点会自动变为True。
2)可利用requires_grad_()方法修改Tensor的requires_grad属性(比如一开始在训练阶段,requires_grad 值设置为了True,在测试阶段修改为 False)。可以调用.detach()或 with torch.no_grad():,将不再计算张量的梯度,跟踪张量的历史记录。这点在评估模 型、测试模型阶段中常常用到。
3)通过运算创建的Tensor(即非叶子节点),会自动被赋予grad_fn属性。该属性表 示梯度函数。叶子节点的grad_fn为None。
4)最后得到的Tensor(根节点)执行backward()函数,此时自动计算各变量的梯度。
- 每次反向传播结束,叶子结点的梯度会被清空。如果需要多次反向传播的梯度累加,需要指定backward 中的参数retain_graph=True,这样子节点的梯度是累加的。
- 非叶子节点的梯度backward调用后即被清空
5)backward()函数接收参数,该参数应和调用backward()函数的Tensor的维度相同, 或者是可broadcast的维度。如果求导的Tensor为标量(即一个数字),则backward中的参数可省略。
三、标量反向传播的计算

- 假设x、w、b都是标量,则计算结果 z 也是标量 (z=wx+b)
- 对根节点z调用backward()方法,我们无须对 backward()传入参数
* 这里先提一嘴,后面会说到的是: 如果目标张量对一个非标量调用backward(),则需要传入一个 gradient参数,该参数也是张量,而且需要与调用backward()的张量形状相同。
以下是实现自动求导的主要步骤:
import torch
# 输入张量 x
x = torch.Tensor([2])
# 初始化 权重参数w, 偏移量b,并设置 require_grad 属性为 True, 为自动求导
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
# 实现向前传播
y = torch.mul(w, x)
z = torch.add(y, b)
# 分别查看叶子节点 x, w, b 和 非叶子节点 y、z 的require_grad属性
print(x.requires_grad, w.requires_grad, b.requires_grad) # False True True
print(y.requires_grad, z.requires_grad ) # True True
# 查看各节点是否为叶子节点
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf) # True True True False False
# 分别查看 叶子节点 和 非叶子节点 的 grad_fn 属性
print(x.grad_fn, w.grad_fn, b.grad_fn) # None None None
print(y.grad_fn, z.grad_fn) # <MulBackward0 object at 0x7f8ac1303910> <AddBackward0 object at 0x7f8ac1303070>
z.backward() # 梯度不会累加
# z.backward(retain_graph=True) # 如果多次使用backward,需要梯度累加,则需要修改参数retain_graph为True
# 查看叶子节点的梯度,x是叶子节点但它无须求导,故其梯度为None
print(w.grad,b.grad,x.grad) # tensor([2.]) tensor([1.]) None
#非叶子节点的梯度,执行backward之后,会自动清空
print(y.grad,z.grad) # None None
四、非标量反向传播的计算
边栏推荐
- Visualize GraphQL schemas with GraphiQL
- Snake Project (Simple)
- OPPO在FaaS领域的探索与思考
- .NET 20th Anniversary Interview - Zhang Shanyou: How .NET technology empowers and changes the world
- Implementing distributed locks based on Redis (SETNX), case: Solving oversold orders under high concurrency
- What is the difference between BI software in the domestic market?
- 【7.29】Code Source - 【Arrangement】【Stone Game II】【Cow and Snacks】【Minimum Number of Spawns】【Sequence】
- 2020微信小程序反编译教程(小程序反编译源码能用吗)
- ansible study notes 02
- 宁波大学NBU IT项目管理期末考试知识点整理
猜你喜欢

宁波大学NBU IT项目管理期末考试知识点整理

长得很怪的箱图

基于Redis(SETNX)实现分布式锁,案例:解决高并发下的订单超卖,秒杀

After Grafana is installed, the web opens and reports an error

The new BMW 3 Series is on the market, with safety and comfort

Implementing DDD based on ABP

mysql black window ~ build database and build table

苹果官网样式调整 结账时产品图片“巨大化”

Visualize GraphQL schemas with GraphiQL

使用 Postman 工具高效管理和测试 SAP ABAP OData 服务的试读版
随机推荐
长得很怪的箱图
Stuck in sill idealTree buildDeps during npm installation, npm installation is slow, npm installation is stuck in one place
Use of radiobutton
Design and Implementation of Compiler Based on C Language
[Meetup Preview] OpenMLDB+OneFlow: Link feature engineering to model training to accelerate machine learning model development
牛客 HJ18 识别有效的IP地址和掩码并进行分类统计
JVM parameter analysis Xmx, Xms, Xmn, NewRatio, SurvivorRatio, PermSize, PrintGC "recommended collection"
全新宝马3系上市,安全、舒适一个不落
How Redis handles concurrent access
字符指针赋值[通俗易懂]
Internet banking stolen?This article tells you how to use online banking safely
Bilateral filtering acceleration "recommended collection"
What is the difference between BI software in the domestic market?
复杂高维医学数据挖掘与疾病风险分类研究
【7.28】代码源 - 【Fence Painting】【合适数对(数据加强版)】
what exactly is json (c# json)
The new BMW 3 Series is on the market, with safety and comfort
2022年必读的12本机器学习书籍推荐
动态规划(一)
ML.NET相关资源整理