PyTorch notes: validation, model.eval vs. torch.no_grad
2022-06-30 10:03:00 【UQI-LIUWJ】
1 A general framework for validation
The model is called model and the optimizer is called optimizer.
import numpy as np
import torch

# model, optimizer, criterion, train_loader, val_loader, epochs and save_path
# are assumed to be defined beforehand
min_val_loss = np.inf
for epoch in range(1, epochs + 1):
    ############################ training part: start ############################
    model.train()
    train_losses = []
    for batch_x, batch_y in train_loader:
        output = model(batch_x)
        loss = criterion(output, batch_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # the PyTorch "big three": zero_grad, backward, step
        train_losses.append(loss.item())
    ############################ training part: end ##############################

    ############################ validation part: start ##########################
    model.eval()
    val_losses = []  # must be reset every epoch
    with torch.no_grad():
        for batch_x, batch_y in val_loader:
            output = model(batch_x)
            loss = criterion(output, batch_y)
            val_losses.append(loss.item())
    val_loss = np.mean(val_losses)
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
        # save the best model so far
    ############################ validation part: end ############################

At test time, load the parameters of this best model with model.load_state_dict and run the test.
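A minimal sketch of that save/load round trip, assuming a hypothetical small nn.Sequential model and a temporary file path (neither comes from the original). Only the state_dict is saved, so at test time the same architecture has to be rebuilt before load_state_dict is called:

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model() -> nn.Module:
    # Hypothetical architecture; at test time it must match the trained one
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Dropout(p=0.5), nn.Linear(8, 1))


torch.manual_seed(0)
model = make_model()
save_path = os.path.join(tempfile.gettempdir(), "best_model.pt")  # hypothetical path

# Validation found a new best: save only the parameters (state_dict)
torch.save(model.state_dict(), save_path)
best_weight = model[0].weight.detach().clone()  # kept only to check the reload below

# Later epochs keep changing the weights (simulated here)
with torch.no_grad():
    for p in model.parameters():
        p.add_(1.0)

# Test time: rebuild the architecture, then load the best parameters
test_model = make_model()
test_model.load_state_dict(torch.load(save_path, map_location="cpu"))
test_model.eval()  # dropout off; batchnorm layers (if any) would use running stats

x_test = torch.randn(5, 4)
with torch.no_grad():  # testing does not need a computation graph either
    preds = test_model(x_test)
```

The reloaded test_model carries the weights from the moment of saving, not the later, modified ones.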
2 model.eval() vs. with torch.no_grad()
2.1 Similarities
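Both are used when running validation in PyTorch. Switching the model into test (evaluation) mode, however, is the job of model.eval(), which changes how certain layers behave.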
For example, for dropout and batchnorm layers:
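- In train mode, a dropout layer zeroes each activation with probability p (so the keep probability is 1-p) and scales the surviving activations by 1/(1-p); a batchnorm layer normalizes with the current batch's mean and var and keeps updating its running mean and var.
- In eval mode, a dropout layer lets every activation through unchanged, and a batchnorm layer stops computing and updating mean and var, using the running mean and var accumulated during training instead.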
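The two bullets above can be checked directly. The sketch below uses standalone nn.Dropout and nn.BatchNorm1d layers on made-up inputs (an assumption for illustration): in train mode dropout outputs only 0 or 1/(1-p), in eval mode it is the identity, and BatchNorm's running_mean changes only in train mode.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.ones(1, 10)

# Dropout: p is the probability of ZEROING an element
drop = nn.Dropout(p=0.5)
drop.train()
y_train = drop(x)  # each element is 0 or 1/(1-p) = 2.0
drop.eval()
y_eval = drop(x)   # identity: all elements pass through unchanged

# BatchNorm: running statistics are updated only in train mode
bn = nn.BatchNorm1d(3)            # running_mean starts at 0, momentum defaults to 0.1
batch = torch.randn(16, 3) + 5.0  # a batch whose per-feature mean is around 5
bn.train()
bn(batch)
mean_after_train = bn.running_mean.clone()  # moved from 0 toward the batch mean

bn.eval()
bn(batch + 100.0)  # normalizes with the running stats but does not update them
mean_after_eval = bn.running_mean.clone()
```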
2.2 Differences
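- model.eval() does not touch autograd: the forward pass still computes and stores everything needed for gradients, exactly as in training mode. The only difference during validation is that we never call backward().
- **with torch.no_grad()** switches autograd off, so no gradients are tracked and no computation graph is stored. This speeds up the forward pass and saves GPU compute and memory, but it does not change the behavior of dropout and batchnorm layers (if the model is still in train mode, they keep behaving as in training).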
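A small sketch of this difference, assuming a hypothetical toy model with one dropout layer: after model.eval() alone the output still carries a grad_fn and backward() still works, while under torch.no_grad() alone no graph is recorded but the dropout layer's training flag is still True.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical small model containing a dropout layer
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5), nn.Linear(4, 1))
x = torch.randn(2, 4)

# Case 1: model.eval() alone. Autograd still records the graph,
# so backward() works exactly as it would in training.
model.eval()
out_eval = model(x)
eval_records_graph = out_eval.requires_grad and out_eval.grad_fn is not None
out_eval.sum().backward()
eval_has_grads = model[0].weight.grad is not None

# Case 2: torch.no_grad() alone. No graph is recorded,
# but the module flags are untouched, so dropout stays active.
model.train()
with torch.no_grad():
    out_nograd = model(x)
    grad_enabled_inside = torch.is_grad_enabled()
nograd_records_graph = out_nograd.requires_grad
dropout_still_training = model[1].training
```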
So in practice, use the two together.
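One way to package the combination is a hypothetical helper named evaluate (not from the original) that also restores the previous mode, so the next training epoch is unaffected. The toy model and the plain list standing in for a DataLoader are assumptions:

```python
import numpy as np
import torch
import torch.nn as nn


def evaluate(model: nn.Module, loader, criterion) -> float:
    """Average validation loss using eval mode AND no_grad together."""
    was_training = model.training
    model.eval()                # dropout off, batchnorm uses running stats
    losses = []
    with torch.no_grad():       # no graph: less memory, faster forward pass
        for batch_x, batch_y in loader:
            output = model(batch_x)
            losses.append(criterion(output, batch_y).item())
    model.train(was_training)   # restore the previous mode
    return float(np.mean(losses))


# Hypothetical usage with a toy model and an in-memory "loader"
torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Dropout(0.5), nn.Linear(8, 1)
)
loader = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]
criterion = nn.MSELoss()

loss1 = evaluate(model, loader, criterion)
loss2 = evaluate(model, loader, criterion)
```

Because eval mode disables dropout and freezes the BatchNorm statistics, calling evaluate twice on the same data returns the same loss, and the model is back in train mode afterwards.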