Gradient Accumulation in PyTorch [When GPU memory limits keep batch_size from growing any further during experiments, use gradient accumulation]
2022-06-12 22:47:00 【u013250861】
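During experiments, GPU memory limits can make it impossible to increase batch_size any further. Gradient accumulation works around this: the model processes several smaller batches, their gradients are summed, and the parameters are updated only once every few batches.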
Training without gradient accumulation:
import numpy as np
import torch

# model, criterion, optimizer and train_loader are assumed to be defined elsewhere
for i, (images, target) in enumerate(train_loader):
    # 1. input -> output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs, target)
    # 2. backward
    optimizer.zero_grad()  # reset gradients
    loss.backward()
    optimizer.step()  # update parameters
Training with gradient accumulation:
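import numpy as np
import torch

# model, criterion, optimizer, train_loader and accumulation_steps are assumed to be defined elsewhere
optimizer.zero_grad()  # make sure no stale gradients leak into the first window
for i, (images, target) in enumerate(train_loader):
    # 1. input -> output
    images = images.cuda(non_blocking=True)
    target = torch.from_numpy(np.array(target)).float().cuda(non_blocking=True)
    outputs = model(images)
    loss = criterion(outputs, target)
    # 2.1 normalize the loss so the accumulated gradient is an average over accumulation_steps batches
    loss = loss / accumulation_steps
    # 2.2 backpropagation: gradients are added to each parameter's .grad
    loss.backward()
    # 3. update the network parameters once every accumulation_steps batches
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()  # update parameters of the net
        optimizer.zero_grad()  # reset gradients for the next window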
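A minimal, self-contained sketch of the accumulation loop that runs on CPU with synthetic data. The toy dataset, the nn.Linear model, the SGD optimizer and the train_one_epoch helper are all hypothetical, chosen only for illustration. Unlike the snippet above, it also performs an update for the final, incomplete window when the number of batches is not a multiple of accumulation_steps.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Hypothetical toy regression data: 100 samples with 10 features each.
X = torch.randn(100, 10)
y = torch.randn(100, 1)
train_loader = DataLoader(TensorDataset(X, y), batch_size=8, shuffle=True)

model = nn.Linear(10, 1)          # hypothetical model
criterion = nn.MSELoss()          # reduction="mean" by default
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accumulation_steps = 4


def train_one_epoch(model, loader, criterion, optimizer, accumulation_steps):
    """Hypothetical helper: one epoch with gradient accumulation.

    Returns the number of optimizer steps taken.
    """
    model.train()
    optimizer.zero_grad()  # start the epoch from clean gradients
    num_steps = 0
    num_batches = len(loader)
    for i, (inputs, targets) in enumerate(loader):
        outputs = model(inputs)
        # Scale the loss so the summed gradient approximates the mean
        # over the effective batch of batch_size * accumulation_steps.
        loss = criterion(outputs, targets) / accumulation_steps
        loss.backward()  # gradients add up in each parameter's .grad
        is_last_batch = (i + 1) == num_batches
        # Update on every full window, and also flush the incomplete last window.
        if (i + 1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()
            num_steps += 1
    return num_steps


steps = train_one_epoch(model, train_loader, criterion, optimizer, accumulation_steps)
print(steps)  # → 4
```

With 100 samples and batch_size=8 there are 13 batches, so the epoch makes 4 optimizer steps (after batches 4, 8, 12 and 13). The last step's gradient comes from a single batch but is still divided by accumulation_steps, so it is smaller than a full step; simply skipping the incomplete window is another common choice.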
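Suppose the original batch size was 32. With gradient accumulation and accumulation_steps=4, you only need batch_size=8: each parameter update then uses gradients from 8 × 4 = 32 samples, which matches the original batch-32 update as long as the loss is averaged over the batch (PyTorch's default reduction="mean"). Layers whose behavior depends on the batch itself, such as BatchNorm, still see only 8 samples per forward pass, so results can differ slightly.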
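A quick numerical check of the batch-32 vs. 4 × 8 equivalence, as a sketch with hypothetical data: it compares the gradient of one batch of 32 samples against four accumulated micro-batches of 8, with each micro-batch loss divided by accumulation_steps. It assumes a mean-reduced loss (the nn.MSELoss default) and a model without batch-dependent layers such as BatchNorm.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical data: one "large" batch of 32 samples with 10 features each.
X = torch.randn(32, 10)
y = torch.randn(32, 1)
criterion = nn.MSELoss()  # reduction="mean" (the default)
accumulation_steps = 4


def make_model():
    torch.manual_seed(1)  # same initial weights for both runs
    return nn.Linear(10, 1)


# (a) Batch size 32: a single forward/backward pass.
model_full = make_model()
criterion(model_full(X), y).backward()

# (b) Batch size 8 with accumulation: 4 micro-batches, each loss scaled by 1/4.
model_acc = make_model()
for xb, yb in zip(X.chunk(accumulation_steps), y.chunk(accumulation_steps)):
    loss = criterion(model_acc(xb), yb) / accumulation_steps
    loss.backward()  # .grad accumulates across the 4 calls

# Compare gradients parameter by parameter (up to floating-point rounding).
grads_match = all(
    torch.allclose(p_full.grad, p_acc.grad, atol=1e-6)
    for p_full, p_acc in zip(model_full.parameters(), model_acc.parameters())
)
print(grads_match)  # → True
```

If the loss used reduction="sum" instead, the division by accumulation_steps should be dropped to match the full-batch gradient; and with micro-batches of unequal size, the simple 1/accumulation_steps scaling is no longer exact.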