PyTorch notes: validation, model.eval vs torch.no_grad
2022-06-30 10:03:00 【UQI-LIUWJ】
1 General framework for validation
The model is `model` and the optimizer is `optimizer`.
```python
import numpy as np
import torch

# model, optimizer, criterion, train_loader, val_loader, epochs and
# save_path are assumed to be defined elsewhere.
min_val_loss = np.inf
for epoch in range(1, epochs + 1):
    ############################ training: start ############################
    model.train()
    train_losses = []
    for batch_x, batch_y in train_loader:
        output = model(batch_x)
        loss = criterion(output, batch_y)
        # the PyTorch "big three": zero_grad / backward / step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
    ############################ training: end ##############################

    ############################ validation: start ##########################
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            output = model(batch_x)
            loss = criterion(output, batch_y)
            val_losses.append(loss.item())
    val_loss = np.mean(val_losses)
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        # save the best model so far
        torch.save(model.state_dict(), save_path)
    ############################ validation: end ############################
```

At test time, load the parameters of this best model (model.load_state_dict) and run the test with them.
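The test-time step can be sketched as follows. This is a minimal, self-contained sketch: a toy nn.Linear model stands in for `model`, and an in-memory io.BytesIO buffer stands in for the file at `save_path` (both are assumptions for illustration, not part of the original code).

```python
import io

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model standing in for `model` in the loop above.
model = nn.Linear(4, 1)

# An in-memory buffer stands in for `save_path` (assumption for this sketch).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)  # what the loop does for the best epoch

# At test time: build a fresh model with the same architecture, load the weights.
test_model = nn.Linear(4, 1)
buffer.seek(0)
test_model.load_state_dict(torch.load(buffer))
test_model.eval()

x_test = torch.randn(8, 4)
with torch.no_grad():
    predictions = test_model(x_test)

print(predictions.shape)  # torch.Size([8, 1])
```

The fresh model must have the same architecture as the saved one; load_state_dict only copies parameter and buffer values.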
2 model.eval() vs. with torch.no_grad()
2.1 Similarities
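Both are used when running validation in PyTorch. Strictly speaking, though, only model.eval() switches the model into evaluation (test) mode; torch.no_grad() leaves layer behavior unchanged (see 2.2).
For example, for dropout and batchnorm layers: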
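- In train mode, a dropout layer zeroes each activation with probability p (the p passed to the layer), so an activation is kept with probability 1 - p and the survivors are scaled by 1/(1 - p); a batchnorm layer normalizes with the current batch's mean and var and keeps updating its running mean and var.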
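- In eval mode, a dropout layer lets all activations through unchanged, and a batchnorm layer stops computing and updating mean and var, using the running mean and var already learned during training instead.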
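A small sketch of the two behaviors described above, using standalone nn.Dropout and nn.BatchNorm1d layers on toy inputs (the layer sizes and input values are arbitrary choices for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Dropout: p is the probability of ZEROING an element.
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()
y_train = drop(x)
# Survivors are scaled by 1 / (1 - p) = 2.0, the rest become 0.
print(sorted(set(y_train.tolist())))  # [0.0, 2.0]

drop.eval()
y_eval = drop(x)
print(torch.equal(y_eval, x))  # True: everything passes through unchanged

# BatchNorm: running statistics are updated only in train mode.
bn = nn.BatchNorm1d(3)
batch = torch.randn(16, 3) * 5 + 10

bn.train()
bn(batch)
mean_after_train = bn.running_mean.clone()  # has moved away from the initial zeros

bn.eval()
bn(batch + 100)  # a very different batch
print(torch.equal(bn.running_mean, mean_after_train))  # True: frozen in eval mode
```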
2.2 Differences
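- model.eval() does not touch autograd: gradients are still tracked and the intermediate results needed for backpropagation are still stored, exactly as in training mode. It only changes how layers such as dropout and batchnorm behave; no backward pass happens simply because loss.backward() is not called.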
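- **with torch.no_grad()** stops autograd from recording operations, so no gradients are computed; this speeds up the forward pass and saves GPU compute and memory. It does not, however, affect dropout and batchnorm layers: they keep behaving according to the module's mode (i.e. as in train mode if model.eval() was not called).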
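The difference can be checked directly. In this sketch (toy model and shapes chosen arbitrarily), model.eval() alone still produces an output that is part of an autograd graph, while torch.no_grad() alone drops the graph but leaves dropout active:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
x = torch.randn(2, 4)

# 1) model.eval() alone: layers switch to eval behavior, but autograd still
#    records the graph, so the output has a grad_fn and requires grad.
model.eval()
out_eval = model(x)
print(out_eval.requires_grad)  # True

# 2) torch.no_grad() alone: no graph is recorded ...
model.train()
with torch.no_grad():
    out_nograd = model(x)
print(out_nograd.requires_grad)  # False

# ... but layer behavior is untouched: a dropout layer left in train
#     mode still zeroes elements inside no_grad.
drop = nn.Dropout(p=0.5)  # modules start in train mode
with torch.no_grad():
    y = drop(torch.ones(1000))
print(bool((y == 0).any()))  # True
```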
So in practice, use the two together.
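Using both together, a validation pass can be wrapped in a small helper. This is a minimal sketch: `evaluate` is a hypothetical name, and a plain list of (x, y) batches stands in for a real DataLoader. Like the np.mean in the framework of section 1, it averages per-batch losses, and it restores the model's previous mode afterwards so training can continue.

```python
import torch
import torch.nn as nn


def evaluate(model: nn.Module, loader, criterion) -> float:
    """Mean per-batch loss with eval mode AND no_grad (hypothetical helper)."""
    was_training = model.training
    model.eval()                  # dropout off, batchnorm uses running stats
    losses = []
    with torch.no_grad():         # no graph recorded: less memory, faster
        for batch_x, batch_y in loader:
            loss = criterion(model(batch_x), batch_y)
            losses.append(loss.item())
    model.train(was_training)     # restore the previous mode
    return sum(losses) / len(losses)


# Usage with toy data (assumption: a list of batches instead of a DataLoader).
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(0.5), nn.Linear(8, 1))
loader = [(torch.randn(5, 4), torch.randn(5, 1)) for _ in range(3)]
criterion = nn.MSELoss()

loss_a = evaluate(model, loader, criterion)
loss_b = evaluate(model, loader, criterion)
print(loss_a == loss_b)  # True: eval mode makes the result deterministic
print(model.training)    # True: training mode was restored
```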