PyTorch Learning Notes 4 - Automatic Gradient Computation with autograd
2022-07-28 06:28:00 【I have two candies】
1. Tensors, Functions and Computational graph
torch.autograd can automatically compute the gradient of every tensor in a computational graph. For example, for the simple one-layer network below, the gradients of the loss with respect to the parameters w and b are computed automatically.

(figure: computational graph of the one-layer network)

This can be implemented as follows:
import torch

x = torch.ones(5)   # input tensor
y = torch.zeros(3)  # expected output
w = torch.randn(5, 3, requires_grad=True)  # weights, tracked by autograd
b = torch.randn(3, requires_grad=True)     # bias, tracked by autograd
z = torch.matmul(x, w) + b  # or equivalently: z = x.matmul(w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
The parameter requires_grad=True tells autograd to track operations on these tensors during the forward pass. After z and loss have been computed, each of them automatically carries a grad_fn attribute: a Function object that knows how to evaluate the operation in the forward direction and how to compute its derivative during back propagation.
print(f'Gradient function for z = {z.grad_fn}')
print(f'Gradient function for loss = {loss.grad_fn}')
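Note that gradient tracking can also be enabled after a tensor has been created, using the in-place method requires_grad_(). A minimal sketch, using a fresh tensor w2 so the w above is left untouched:

w2 = torch.randn(5, 3)   # created without gradient tracking
w2.requires_grad_(True)  # enable tracking in place
print(w2.requires_grad)  # True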
2. Computing Gradients
Calling loss.backward() performs one backward pass (BP), which automatically computes $\frac{\partial loss}{\partial w}$ and $\frac{\partial loss}{\partial b}$. The gradients of loss with respect to w and b can then be read from w.grad and b.grad:
loss.backward()
print(w.grad)
print(b.grad)
# tensor([[0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266]])
# tensor([0.1814, 0.0460, 0.3266])
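A quick sanity check on these numbers: with the default mean reduction, $\frac{\partial loss}{\partial z_j} = \frac{1}{3}(\sigma(z_j) - y_j)$, so b.grad equals this vector and w.grad is its outer product with x. Every row of w.grad is identical precisely because x is all ones. A minimal verification sketch:

manual = (torch.sigmoid(z) - y) / 3                    # d loss / d z for mean-reduced BCE-with-logits
print(torch.allclose(b.grad, manual))                  # True
print(torch.allclose(w.grad, torch.outer(x, manual)))  # True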
Note:
1. We can only obtain the grad properties for the leaf nodes of the computational graph, i.e. the tensors created with requires_grad=True. For all other nodes in the graph, gradients will not be available.
2. For performance reasons, we can only perform gradient calculation with backward once on a given graph. If we need to do several backward calls on the same graph, we need to pass retain_graph=True to the backward call, as the sketch below shows.
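A minimal sketch of point 2, rebuilding the loss on the graph above; without retain_graph=True the second backward call would raise a RuntimeError:

z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward(retain_graph=True)  # keep the graph alive for another pass
loss.backward()                   # works; w.grad now holds the sum of both passes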
3. Disabling Gradient Tracking
All tensors with requires_grad=True record their computation history and support gradient computation. Sometimes, however, we do not need gradients, for example when we only want a trained model to make predictions on some samples. In that case, we can wrap the computation in torch.no_grad() to stop all gradient tracking:
z = torch.matmul(x, w) + b
print(z.requires_grad)  # True

with torch.no_grad():
    z = torch.matmul(x, w) + b
print(z.requires_grad)  # False
Alternatively, you can use z.detach():
z = torch.matmul(x, w) + b
print(z.detach().requires_grad) # False
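Note that detach() returns a new tensor that is excluded from the graph; the original tensor is unchanged and still tracked:

z = torch.matmul(x, w) + b
z_det = z.detach()
print(z.requires_grad)      # True: the original tensor is still tracked
print(z_det.requires_grad)  # False: the detached tensor is outside the graph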
When evaluating a model's performance, remember to run it inside with torch.no_grad()!
Other scenarios where disabling gradient tracking is useful: freezing part of a network, fine-tuning a pretrained model, and speeding up the forward pass when only inference is needed. A sketch of freezing parameters follows.
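Freezing layers just means switching off requires_grad on their parameters. A sketch under the assumption that torchvision is installed; the ResNet here is only an illustration, not part of the original post:

import torch
from torchvision import models

model = models.resnet18(pretrained=True)  # older torchvision API; newer versions use weights=...
for param in model.parameters():
    param.requires_grad = False           # freeze the pretrained backbone
model.fc = torch.nn.Linear(512, 10)       # new head; its parameters require grad by default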
4. forward & backward
Forward propagation
1. Compute the result of the requested operation.
2. Record each operation's gradient function (grad_fn) in the computational graph.
Back propagation (triggered by .backward())
1. Compute the gradient from each .grad_fn.
2. Accumulate the gradient into the corresponding tensor's .grad attribute.
3. Use the chain rule to propagate gradients all the way to the leaf tensors.
Note that step 2 accumulates gradients rather than overwriting them, so they must be zeroed between training iterations; see the sketch below.
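A minimal sketch of the accumulation behavior, reusing the tensors defined earlier; running backward twice doubles w.grad unless it is zeroed in between:

w.grad.zero_()  # start from a clean slate
loss = torch.nn.functional.binary_cross_entropy_with_logits(torch.matmul(x, w) + b, y)
loss.backward()
first = w.grad.clone()
loss = torch.nn.functional.binary_cross_entropy_with_logits(torch.matmul(x, w) + b, y)
loss.backward()
print(torch.allclose(w.grad, 2 * first))  # True: the second pass accumulated
w.grad.zero_()  # in a training loop, optimizer.zero_grad() does this for you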
REFERENCE:
1. https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html#disabling-gradient-tracking
For more information, please refer to: PyTorch Learning Notes