PyTorch Learning Notes 4: Automatic Gradient Computation with autograd
2022-07-28 06:28:00 【I have two candies】
1. Tensors, Functions and Computational graph
torch.autograd can automatically compute the gradient of every tensor in the computational graph. For example, in the computational graph below, the gradients of w and b are computed automatically:

This can be implemented as follows:
import torch
x = torch.ones(5)
y = torch.zeros(3)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)
z = torch.matmul(x, w) + b
# or (z = x.matmul(w) + b)
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
Setting requires_grad=True tells autograd to track operations on a tensor during the forward pass. After z and loss have been computed, each of them automatically gets a grad_fn attribute, a Function object that knows how to compute the forward result and, during backpropagation, the corresponding gradients.
print(f'Gradient function for z = {z.grad_fn}')
print(f'Gradient function for loss = {loss.grad_fn}')
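On a typical run the output looks like this (the exact object addresses will differ):
# Gradient function for z = <AddBackward0 object at 0x...>
# Gradient function for loss = <BinaryCrossEntropyWithLogitsBackward0 object at 0x...>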
2. Computing Gradients
Calling loss.backward() performs one backward pass, automatically computing $\frac{\partial loss}{\partial w}$ and $\frac{\partial loss}{\partial b}$. The results are then available through w.grad and b.grad, the gradients of loss with respect to w and b:
loss.backward()
print(w.grad)
print(b.grad)
# tensor([[0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266],
# [0.1814, 0.0460, 0.3266]])
# tensor([0.1814, 0.0460, 0.3266])
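As a sanity check (a sketch added here, not part of the original notes): for binary cross-entropy with logits and mean reduction, $\frac{\partial loss}{\partial z} = (\sigma(z) - y) / 3$, and since z = x @ w + b with x all ones, every row of w.grad equals this same vector, which explains the repeated rows in the output above:
with torch.no_grad():
    dz = (torch.sigmoid(z) - y) / y.numel()            # d(loss)/dz
    print(torch.allclose(w.grad, torch.outer(x, dz)))  # True: d(loss)/dw = outer(x, dz)
    print(torch.allclose(b.grad, dz))                  # True: d(loss)/db = dz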
Note
1. We can only obtain the grad properties for the leaf nodes of the computational graph, which have requires_grad set to True. For all other nodes in our graph, gradients will not be available.
2. For performance reasons, we can only perform gradient calculation using backward once on a given graph. If we need to do several backward calls on the same graph, we need to pass retain_graph=True to the backward call (see the sketch below).
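A minimal sketch of the second point, reusing the tensors defined above:
z = torch.matmul(x, w) + b
loss = torch.nn.functional.binary_cross_entropy_with_logits(z, y)
loss.backward(retain_graph=True)   # keep the graph alive for a second backward pass
loss.backward()                    # works now; note that gradients accumulate into .grad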
3. Disabling Gradient Tracking
All tensors with requires_grad=True record the operations performed on them and support gradient computation. Sometimes, however, we do not need gradients at all, for example when we only want the model to make predictions on some samples. In that case we can wrap the computation in torch.no_grad() to stop all gradient tracking:
z = torch.matmul(x, w) + b
print(z.requires_grad) # True
with torch.no_grad():
    z = torch.matmul(x, w) + b
print(z.requires_grad) # False
Alternatively, you can use z.detach(), which returns a new tensor detached from the computation graph:
z = torch.matmul(x, w) + b
print(z.detach().requires_grad) # False
When evaluating a model's performance, the forward passes should be run under with torch.no_grad()!
Other scenarios where disabling gradient tracking is useful: freezing part of the network, fine-tuning a pretrained model, and speeding up computation when only the forward pass is needed.
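As an illustration of the frozen-network case, here is a minimal sketch (the small Sequential model is made up for this example, not taken from the original notes):
import torch.nn as nn
model = nn.Sequential(nn.Linear(5, 8), nn.ReLU(), nn.Linear(8, 3))
for param in model[0].parameters():   # freeze the first layer
    param.requires_grad = False
out = model(torch.ones(5))
out.sum().backward()
print(model[0].weight.grad)   # None -- the frozen layer receives no gradient
print(model[2].weight.grad)   # a tensor -- this layer is still trainable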
4. Forward & Backward
Forward propagation
1. Compute the result of the forward pass and record each operation's grad_fn.
Backward propagation
1. Compute the gradient corresponding to each .grad_fn.
2. Accumulate each gradient into the corresponding tensor's .grad attribute.
3. Use the chain rule to propagate gradients all the way down to the leaf tensors (see the sketch below).
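A tiny end-to-end example of these two passes (a sketch, not part of the original notes):
a = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)  # leaf tensor
s = (a ** 2).sum()   # forward pass: each operation records its grad_fn
s.backward()         # backward pass: the chain rule gives ds/da = 2 * a
print(a.grad)        # tensor([2., 4., 6.])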
REFERENCE:
1 . https://pytorch.org/tutorials/beginner/basics/autogradqs_tutorial.html#disabling-gradient-tracking
For more information, see the author's PyTorch Learning Notes series.