当前位置:网站首页>【pytorch】pytorch 自动求导、 Tensor 与 Autograd
【pytorch】pytorch 自动求导、 Tensor 与 Autograd
2022-07-31 16:15:00 【Enzo 想砸电脑】
在神经网络中,一个重要内容就是进行参数学习,而参数学习离不开求导。
现在大部分深度学习架构都有自动求导的功能,torch.autograd包 就是用来自动求导的。
torch.autograd 包为张量上 所有的操作 提供了自动求导功能
这一篇学习并记录一下 自动求导 的要点。
一、计算图
在整个向前计算过程中,PyTorch采用计算图的形式进行组织,该计算图为动态图,且在每次 前向传播时,将重新构建。其他深度学习架构,如TensorFlow、Keras一般为静态图。
- 计算图是一种有向无环图像,用图形方式来表示算子与变量之间的关系,直观高效。
- 图中 圆形表示变量,矩阵表示算子
- 表达式:z=wx+b,可写成两个表示式: y=wx,z=y+b,
- 其中x、w、b为变量,是用户创建的变量,不依赖于其他变量,故又称 为叶子节点。为计算各叶子节点的梯度,需要把对应的张量参数requires_grad属性设置为 True,这样就可自动跟踪其历史记录。(后面会细说)
- y、z 是计算得到的变量,非叶子节点,z为根节点
- mul和add是算子(或操作或函数)
这些变量及算子就构成了一个完整的计算过程 (或前向传播过程)
二、自动求导要点
为实现对Tensor自动求导,需考虑如下事项:
1)创建叶子节点(Leaf Node)的Tensor,使用requires_grad参数指定是否记录对其 的操作,以便之后利用backward()方法进行梯度求解。requires_grad参数的缺省值为 False,如果要对其求导需设置为True,然后与之有依赖关系的节点会自动变为True。
2)可利用requires_grad_()方法修改Tensor的requires_grad属性(比如一开始在训练阶段,requires_grad 值设置为了True,在测试阶段修改为 False)。可以调用.detach()或 with torch.no_grad():,将不再计算张量的梯度,跟踪张量的历史记录。这点在评估模 型、测试模型阶段中常常用到。
3)通过运算创建的Tensor(即非叶子节点),会自动被赋予grad_fn属性。该属性表 示梯度函数。叶子节点的grad_fn为None。
4)最后得到的Tensor(根节点)执行backward()函数,此时自动计算各变量的梯度。
- 每次反向传播结束,叶子结点的梯度会被清空。如果需要多次反向传播的梯度累加,需要指定backward 中的参数retain_graph=True,这样子节点的梯度是累加的。
- 非叶子节点的梯度backward调用后即被清空
5)backward()函数接收参数,该参数应和调用backward()函数的Tensor的维度相同, 或者是可broadcast的维度。如果求导的Tensor为标量(即一个数字),则backward中的参数可省略。
三、标量反向传播的计算
- 假设x、w、b都是标量,则计算结果 z 也是标量 (z=wx+b)
- 对根节点z调用backward()方法,我们无须对 backward()传入参数
* 这里先提一嘴,后面会说到的是: 如果目标张量对一个非标量调用backward(),则需要传入一个 gradient参数,该参数也是张量,而且需要与调用backward()的张量形状相同。
以下是实现自动求导的主要步骤:
import torch
# 输入张量 x
x = torch.Tensor([2])
# 初始化 权重参数w, 偏移量b,并设置 require_grad 属性为 True, 为自动求导
w = torch.randn(1, requires_grad=True)
b = torch.randn(1, requires_grad=True)
# 实现向前传播
y = torch.mul(w, x)
z = torch.add(y, b)
# 分别查看叶子节点 x, w, b 和 非叶子节点 y、z 的require_grad属性
print(x.requires_grad, w.requires_grad, b.requires_grad) # False True True
print(y.requires_grad, z.requires_grad ) # True True
# 查看各节点是否为叶子节点
print(x.is_leaf, w.is_leaf, b.is_leaf, y.is_leaf, z.is_leaf) # True True True False False
# 分别查看 叶子节点 和 非叶子节点 的 grad_fn 属性
print(x.grad_fn, w.grad_fn, b.grad_fn) # None None None
print(y.grad_fn, z.grad_fn) # <MulBackward0 object at 0x7f8ac1303910> <AddBackward0 object at 0x7f8ac1303070>
z.backward() # 梯度不会累加
# z.backward(retain_graph=True) # 如果多次使用backward,需要梯度累加,则需要修改参数retain_graph为True
# 查看叶子节点的梯度,x是叶子节点但它无须求导,故其梯度为None
print(w.grad,b.grad,x.grad) # tensor([2.]) tensor([1.]) None
#非叶子节点的梯度,执行backward之后,会自动清空
print(y.grad,z.grad) # None None
四、非标量反向传播的计算
边栏推荐
- Graham's Scan method for solving convex hull problems
- Delete table data or clear table
- tooltips使用教程(鼠标悬停时显示提示)
- 牛客网刷题(二)
- 百度网盘网页版加速播放(有可用的网站吗)
- Deployment application life cycle and Pod health check
- [Meetup Preview] OpenMLDB+OneFlow: Link feature engineering to model training to accelerate machine learning model development
- 11 pinia use
- Oracle动态注册非1521端口
- Internet banking stolen?This article tells you how to use online banking safely
猜你喜欢
How Redis handles concurrent access
LevelSequence源码分析
EF Core 2.2中将ORM框架生成的SQL语句输出到控制台
Tencent Cloud Deployment----DevOps
mongo enters error
How C programs run 01 - the composition of ordinary executable files
MySQL基础篇【单行函数】
研发过程中的文档管理与工具
WPF project - basic usage of controls entry, you must know XAML
Kubernetes principle analysis and practical application manual, too complete
随机推荐
npm安装时卡在sill idealTree buildDeps,npm安装速度慢,npm安装卡在一个地方不动
type of timer
Internet banking stolen?This article tells you how to use online banking safely
基于C语言的编译器设计与实现
arm按键控制led灯闪烁(嵌入式按键实验报告)
Snake Project (Simple)
2.索引及调优篇【mysql高级】
mysql black window ~ build database and build table
Premiere Pro 2022 for (pr 2022)v22.5.0
牛客 HJ18 识别有效的IP地址和掩码并进行分类统计
ML.NET相关资源整理
小程序:matlab解微分方程「建议收藏」
01 Encounter typescript, build environment
在资源管理类中提供对原始资源的访问——条款15
SringMVC中个常见的几个问题
Kubernetes principle analysis and practical application manual, too complete
Browser's built-in color picker
2022年整理LeetCode最新刷题攻略分享(附中文详细题解)
After Grafana is installed, the web opens and reports an error
Linux check redis version (check mongodb version)