Why should the gradient be manually cleared before backpropagation in PyTorch?
2022-07-03 18:58:00 【Xiaobai learns vision】
https://www.zhihu.com/question/303070254
This article is shared for academic purposes only; it will be deleted in case of infringement.
Why does PyTorch require manually zeroing the gradient before backpropagation?
Author: Pascal
https://www.zhihu.com/question/303070254/answer/573037166
This design lets you do more with gradients, for example gradient accumulation.
A conventional training loop handles one batch like this:
for i, (images, target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs, target)
    # 2. backward
    optimizer.zero_grad()   # reset gradient
    loss.backward()
    optimizer.step()

Obtain the loss: feed in the images and labels, run the forward pass (inference) to get the predicted values, and compute the loss;
optimizer.zero_grad() clears the gradients from previous iterations;
loss.backward() backpropagates and computes the current gradients;
optimizer.step() updates the network parameters according to the gradients.
In short: for each batch of data, the gradient is computed once and the network is updated once.
With gradient accumulation, the loop is written like this:
for i, (images, target) in enumerate(train_loader):
    # 1. input output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs, target)
    # 2.1 normalize the loss by the number of accumulation steps
    loss = loss / accumulation_steps
    # 2.2 back propagation
    loss.backward()
    # 3. update parameters of net
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()        # update parameters of net
        optimizer.zero_grad()   # reset gradient

Obtain the loss: feed in the images and labels, run the forward pass (inference) to get the predicted values, and compute the loss;
loss.backward() backpropagates and computes the current gradients;
repeat steps 1-2 several times without clearing the gradients, so each new gradient is accumulated on top of the existing one;
after the gradients have accumulated for a certain number of steps, optimizer.step() updates the network parameters using the accumulated gradient, and then optimizer.zero_grad() clears the gradients to prepare for the next round of accumulation.
In conclusion: with gradient accumulation, each batch of data produces one gradient computation, the gradient is not cleared but keeps accumulating, and after a certain number of accumulations the network parameters are updated with the accumulated gradient, the gradient is cleared, and the next cycle begins.
Within limits, a larger batch size trains better, and gradient accumulation enlarges the batch size "in disguise": with accumulation_steps set to 8, the effective batch size is 8 times larger. For cash-strapped labs like ours this is a good trick for working around limited GPU memory. Note that when using it, the learning rate should also be enlarged appropriately.
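As a rough sketch of the learning-rate point (my own example; the linear-scaling heuristic and the names base_lr / accumulation_steps are assumptions, not from the answer):

import torch

accumulation_steps = 8                       # effective batch is 8x the real one
base_lr = 0.01                               # learning rate tuned for the real batch size
scaled_lr = base_lr * accumulation_steps     # one common heuristic: scale linearly

model = torch.nn.Linear(10, 2)               # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=scaled_lr)

Whether linear scaling is appropriate depends on the optimizer and schedule; treat it as a starting point to tune, not a rule.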
Update 1: Regarding whether this affects BN, someone has said:
As far as I know, batch norm statistics get updated on each forward pass, so no problem if you don't do .backward() every time.
BN's statistics are estimated during the forward pass, so there is no conflict. It is just that accumulation_steps = 8 is not the same as a real batch size 8 times larger, so the result is naturally a bit worse: the BN mean and variance estimated from a batch 8 times larger would certainly be more accurate.
Update 2: According to Shaohua Li's answer, you can lower BN's own momentum parameter.
BN has a momentum parameter: x_new_running = (1 - momentum) * x_running + momentum * x_new_observed. The closer momentum is to 0, the longer the old running stats are remembered, so the statistics effectively cover a longer sequence of batches.
I took a quick look at the PyTorch 1.0 source code: https://github.com/pytorch/pytorch/blob/162ad945902e8fc9420cbd0ed432252bd7de673a/torch/nn/modules/batchnorm.py#L24 — the momentum attribute of the BN class defaults to 0.1, and you can try adjusting it.
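For reference, a minimal sketch of lowering momentum on every BN layer of an existing model (the value 0.01 and the helper name are illustrative assumptions, not from the answer):

import torch.nn as nn

def lower_bn_momentum(model: nn.Module, momentum: float = 0.01) -> None:
    # A smaller momentum makes the running mean/var average over a longer
    # history of (small) batches, which helps when gradients are accumulated.
    for m in model.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            m.momentum = momentum

# usage: lower_bn_momentum(model, momentum=0.01)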
Author: Forever123
https://www.zhihu.com/question/303070254/answer/608153308
The reason is that in PyTorch the computed gradient values are accumulated.
The benefit of this can be seen from the perspective of memory consumption.
1. Version 1
In PyTorch, a standard train-from-scratch loop for a multi-task setup looks like this:
for idx, data in enumerate(train_loader):
    xs, ys = data
    pred1 = model1(xs)
    pred2 = model2(xs)
    loss1 = loss_fn1(pred1, ys)
    loss2 = loss_fn2(pred2, ys)
    # ******
    loss = loss1 + loss2
    optimizer.zero_grad()
    loss.backward()
    # ++++++
    optimizer.step()

By PyTorch's design, every forward pass that produces a pred builds a computation graph for backpropagation. This graph stores the intermediate results needed for the backward pass, and once .backward() is called, the graph is freed from memory.
When the code above reaches ******, memory holds two computation graphs; summing them into loss merges the two graphs, and the change in size is negligible.
When execution reaches ++++++, the corresponding grad values have been obtained and the memory is freed. So two computation graphs have to be kept in memory during training, and if the loss is composed from more sources, memory consumption is even higher.
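A small sketch of my own that illustrates the "graph is freed after .backward()" behaviour: a second backward pass through the same graph fails unless retain_graph=True is passed.

import torch

x = torch.randn(3, requires_grad=True)
y = (x * 2).sum()          # forward pass builds a computation graph

y.backward()               # gradients computed, graph buffers are freed
try:
    y.backward()           # reusing the freed graph raises a RuntimeError
except RuntimeError as err:
    print("second backward failed:", err)

z = (x * 2).sum()
z.backward(retain_graph=True)   # keep the graph alive explicitly
z.backward()                    # now a second pass works; grads keep accumulating in x.grad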
2. Version 2
To reduce per-iteration memory consumption, gradient accumulation gives us the following variant:
for idx, data in enumerate(train_loader):
    xs, ys = data
    optimizer.zero_grad()
    # compute d(l1)/d(x)
    pred1 = model1(xs)            # builds graph1
    loss1 = loss_fn1(pred1, ys)
    loss1.backward()              # frees graph1
    # compute d(l2)/d(x)
    pred2 = model2(xs)            # builds graph2
    loss2 = loss_fn2(pred2, ys)
    loss2.backward()              # frees graph2
    # optimize with d(l1)/d(x) + d(l2)/d(x)
    optimizer.step()

As the code shows, with gradient accumulation the multi-task training can be done while keeping at most one computation graph in memory at a time.
Another use is when GPU memory is not large enough: accumulate the grads of several small batches and use them as one big batch for the update, since the gradients obtained the two ways are equivalent.
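A quick sanity check (my own sketch, not part of the answer) that accumulating small-batch gradients matches the big-batch gradient when the loss uses mean reduction and each small-batch loss is divided by the number of accumulation steps:

import torch

torch.manual_seed(0)
model_a = torch.nn.Linear(4, 1)
model_b = torch.nn.Linear(4, 1)
model_b.load_state_dict(model_a.state_dict())    # identical weights

x = torch.randn(8, 4)
y = torch.randn(8, 1)
criterion = torch.nn.MSELoss()                   # mean reduction

criterion(model_a(x), y).backward()              # one big batch of 8

for xb, yb in ((x[:4], y[:4]), (x[4:], y[4:])):  # two accumulated batches of 4
    (criterion(model_b(xb), yb) / 2).backward()

print(torch.allclose(model_a.weight.grad, model_b.weight.grad))  # True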
To sum up, gradient accumulation is a very memory-friendly idea, which reflects FAIR's design philosophy.
Author: blateyang
https://www.zhihu.com/question/303070254/answer/535552845
The simple reason is that PyTorch accumulates gradients by default.
As for why PyTorch has this behaviour, the explanation found online is that PyTorch's autograd mechanism is very flexible: you can take the gradient of a tensor, use that gradient in further computation, and then compute gradients of the new operations again, so there is no well-defined point at which the forward computation stops. Zeroing gradients automatically would therefore be tricky, because the framework cannot know when one computation ends and a new one begins.
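To make the "no well-defined stopping point" idea concrete, here is a small sketch of my own in which a gradient is itself used in further computation and then differentiated again:

import torch

x = torch.tensor(2.0, requires_grad=True)
y = x ** 3                                           # y = x^3

# First-order gradient, kept in the graph so it can be differentiated again
(g,) = torch.autograd.grad(y, x, create_graph=True)  # g = 3 * x^2 = 12
z = g ** 2                                           # keep computing with the gradient
(g2,) = torch.autograd.grad(z, x)                    # dz/dx = 36 * x^3 = 288

print(g.item(), g2.item())                           # 12.0 288.0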
The advantage of accumulating by default is that, in multi-task settings where several computations share an earlier tensor, calling backward on each task's loss automatically accumulates the gradients on those shared tensors. The disadvantage is that when you do not want previous gradients to affect the current computation, you have to clear them manually.
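A minimal sketch of my own showing both points: two losses share the same leaf tensor, each backward call adds its contribution to x.grad, and the gradient has to be zeroed manually when accumulation is not wanted.

import torch

x = torch.ones(3, requires_grad=True)   # tensor shared by both "tasks"

loss1 = (x * 2).sum()                   # d(loss1)/dx = 2
loss2 = (x * 3).sum()                   # d(loss2)/dx = 3

loss1.backward()
loss2.backward()
print(x.grad)                           # tensor([5., 5., 5.]) -- contributions add up

x.grad.zero_()                          # clear manually before the next, unrelated computation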