PyTorch Study Notes 4: Automatic Gradient Computation with AUTOGRAD
2022-07-28 05:24:00 【我有两颗糖】
1. Tensors, Functions and Computational Graph
torch.autograd automatically computes the derivative of every node in a computational graph, for example the gradients of w and b in the graph below.

[Figure: computational graph — x, w, and b combine through matmul and add into z, which feeds a binary cross-entropy loss]

This can be implemented as follows:
import torch

x = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)  # weights: gradients will be tracked
b = torch.randn(3, requires_grad=True)     # bias: gradients will be tracked
z = torch.matmul(x, w) + b  # or: z = x.matmul(w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
The requires_grad=True argument tells autograd to track operations on these tensors during the forward pass. After z and loss are computed, each tensor gets a grad_fn attribute: a Function object that knows how to apply the operation in the forward direction and how to compute its derivative during back propagation.
print(f'Gradient function for z = {z.grad_fn}')
print(f'Gradient function for loss = {loss.grad_fn}')
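This prints something like the following (the object addresses, and on older PyTorch versions the exact class names, will differ):

Gradient function for z = <AddBackward0 object at 0x...>
Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x...>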
2. Computing Gradients
Calling loss.backward() performs one pass of back propagation, automatically computing $\frac{\partial loss}{\partial w}$ and $\frac{\partial loss}{\partial b}$; the derivatives of loss with respect to w and b can then be read from w.grad and b.grad:
loss.backward()
print(w.grad)
print(b.grad)
# Example output -- the exact values depend on the random initialization:
# tensor([[0.1814, 0.0460, 0.3266],
#         [0.1814, 0.0460, 0.3266],
#         [0.1814, 0.0460, 0.3266],
#         [0.1814, 0.0460, 0.3266],
#         [0.1814, 0.0460, 0.3266]])
# tensor([0.1814, 0.0460, 0.3266])
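As a sanity check (my own sketch, not part of the original tutorial): with the default mean reduction, binary_cross_entropy_with_logits has $\frac{\partial loss}{\partial z} = (\sigma(z) - y) / 3$, and the chain rule gives $\frac{\partial loss}{\partial w} = x \otimes \frac{\partial loss}{\partial z}$. Since every element of x is 1, all rows of w.grad are identical, exactly as in the output above. Continuing with the tensors defined earlier:

grad_z = (torch.sigmoid(z) - y) / z.numel()  # d(loss)/dz under mean reduction
print(torch.allclose(w.grad, torch.outer(x, grad_z)))  # True
print(torch.allclose(b.grad, grad_z))                  # True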
Note
1. We can only obtain the grad properties for the leaf nodes of the computational graph, which have the requires_grad property set to True. For all other nodes in our graph, gradients will not be available.
2. We can only perform gradient calculations using backward once on a given graph, for performance reasons. If we need to do several backward calls on the same graph, we need to pass retain_graph=True to the backward call, as in the sketch below.
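A minimal sketch of point 2 (the tensors from section 1 are redefined so the snippet runs on its own):

import torch

x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)

loss.backward(retain_graph=True)  # keep the saved buffers for another pass
loss.backward()  # a second call now works; without retain_graph it raises a RuntimeError
print(w.grad)    # note: gradients accumulate, so this is twice the single-pass gradient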
3. Disabling Gradient Tracking
All tensors with requires_grad=True track their computation history and support gradient computation. Sometimes, however, we do not need gradients, for example when we only want a trained model to make predictions on some inputs. In that case we can use torch.no_grad() to stop all gradient tracking:
z = torch.matmul(x, w) + b
print(z.requires_grad)  # True

with torch.no_grad():
    z = torch.matmul(x, w) + b
print(z.requires_grad)  # False
Alternatively, you can use z.detach(), which returns a new tensor detached from the computational graph:
z = torch.matmul(x, w) + b
print(z.detach().requires_grad) # False
When evaluating a model's performance, always run it under with torch.no_grad()!
Other scenarios: freezing part of a network, fine-tuning a pretrained model, and speeding up the forward pass; a freezing sketch follows.
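A minimal sketch of freezing for fine-tuning (the model and layer sizes are illustrative, not from the original post):

from torch import nn

model = nn.Sequential(
    nn.Linear(5, 16),
    nn.ReLU(),
    nn.Linear(16, 3),
)

# Freeze every parameter of the pretrained backbone ...
for param in model.parameters():
    param.requires_grad_(False)

# ... then replace the head; a freshly created layer tracks gradients by default.
model[-1] = nn.Linear(16, 3)

trainable = [p for p in model.parameters() if p.requires_grad]
print(sum(p.numel() for p in trainable))  # 51: only the new head (16*3 + 3)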
4. Forward & Backward
Forward pass
1. Compute the result of each operation and save the operation's grad_fn
Backward pass
1. Compute the gradient from each .grad_fn
2. Accumulate the gradient into the corresponding tensor's .grad attribute
3. Using the chain rule, propagate the gradients all the way to the leaf tensors
A tiny worked example of these mechanics is sketched below.
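A minimal sketch (my own example, not from the original post) of these mechanics on the tiny scalar graph y = (x * w)^2:

import torch

x = torch.tensor(2.0)
w = torch.tensor(3.0, requires_grad=True)

y = (x * w) ** 2  # forward: each operation records its grad_fn
print(y.grad_fn)  # <PowBackward0 ...>

y.backward()      # backward: chain rule from y down to the leaf tensor w
print(w.grad)     # dy/dw = 2 * x^2 * w = 24 -> tensor(24.)

y = (x * w) ** 2
y.backward()
print(w.grad)     # tensor(48.) -- gradients accumulate into .grad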
REFERENCE:
1. https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html#disabling-gradient-tracking
For more, see the rest of this PyTorch Study Notes series.