PyTorch (6): Model Tuning Tricks
2022-07-07 04:44:00 【CyrusMay】
1. Regularization
1.1 L1 Regularization
import torch
import torch.nn.functional as F
from torch import nn
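device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # fall back to CPU
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    )
MLP.to(device)
loss_classify = nn.CrossEntropyLoss().to(device)
# Hypothetical batch: 16 samples with 128 features, labels in [0, 10)
x = torch.randn(16, 128, device=device)
y = torch.randint(0, 10, (16,), device=device)
# L1 norm: sum of |w| over all parameters
l1_lambda = 0.01  # regularization strength (assumed value)
l1_loss = 0
for param in MLP.parameters():
    l1_loss += torch.sum(torch.abs(param))
# loss_classify is a loss module: call it on (logits, labels), then add the weighted penalty
loss = loss_classify(MLP(x), y) + l1_lambda * l1_loss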
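The snippet above builds the L1 penalty and the total loss. As a minimal sketch on random data (the batch `x`, `y`, the smaller model, and `l1_lambda` are assumptions, not from the original), the code below runs one regularized training step on CPU and checks that the L1 term adds exactly `l1_lambda * sign(w)` to each parameter's gradient. That constant-size push is what drives small weights toward zero.

```python
import torch
from torch import nn

torch.manual_seed(0)

model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
criterion = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
l1_lambda = 1e-3  # assumed regularization strength

# Hypothetical batch of 16 samples, 10 classes
x = torch.randn(16, 128)
y = torch.randint(0, 10, (16,))


def l1_penalty(module):
    """Sum of |w| over every parameter of the module."""
    return sum(p.abs().sum() for p in module.parameters())


# Gradients of the data loss alone, kept for comparison
opt.zero_grad()
criterion(model(x), y).backward()
data_grads = [p.grad.clone() for p in model.parameters()]

# Gradients of data loss + weighted L1 penalty
opt.zero_grad()
loss = criterion(model(x), y) + l1_lambda * l1_penalty(model)
loss.backward()

# Extra gradient contributed by the L1 term, versus l1_lambda * sign(w)
extra = [p.grad - g for p, g in zip(model.parameters(), data_grads)]
expected = [l1_lambda * torch.sign(p.detach()) for p in model.parameters()]

opt.step()  # apply one regularized SGD update
```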
1.2 L2 Regularization
import torch
import torch.nn.functional as F
from torch import nn
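device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # fall back to CPU
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    )
MLP.to(device)
# L2 norm: weight_decay adds weight_decay * w to every gradient,
# which is the gradient of an L2 penalty (weight_decay / 2) * ||w||^2 on the loss
opt = torch.optim.SGD(MLP.parameters(), lr=0.001, weight_decay=0.1)
criterion = nn.CrossEntropyLoss().to(device)  # the classification loss module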
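To see why `weight_decay` counts as L2 regularization, the sketch below (small model, random batch, and `deepcopy` setup are assumptions for illustration) takes one SGD step with `weight_decay=0.1` and one plain SGD step on a loss with the explicit penalty `(wd / 2) * sum(w ** 2)`, starting from identical weights. The resulting parameters should match. Note that, used this way, the decay also applies to the bias terms.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)

wd, lr = 0.1, 0.001
model_a = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
model_b = copy.deepcopy(model_a)  # identical starting weights
criterion = nn.CrossEntropyLoss()

# Hypothetical batch
x = torch.randn(16, 128)
y = torch.randint(0, 10, (16,))

# (a) L2 through the optimizer's weight_decay argument
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, weight_decay=wd)
opt_a.zero_grad()
criterion(model_a(x), y).backward()
opt_a.step()

# (b) L2 written explicitly into the loss: (wd / 2) * sum of squared parameters
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr)
opt_b.zero_grad()
l2 = sum((p ** 2).sum() for p in model_b.parameters())
(criterion(model_b(x), y) + wd / 2 * l2).backward()
opt_b.step()

same = all(torch.allclose(pa, pb, atol=1e-7)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)
```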
2. Momentum and Learning-Rate Decay
2.1 momentum
opt = torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.78,weight_decay=0.1)
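With `momentum` set (and the defaults `dampening=0`, `nesterov=False`), SGD keeps a velocity buffer `v = momentum * v + g` and updates `w = w - lr * v`, where `g` already includes the `weight_decay * w` term. The sketch below checks this rule against `torch.optim.SGD` using the hyperparameters from the line above; the single parameter tensor and the toy loss `0.5 * ||w||^2` (whose gradient is `w`) are assumptions chosen to keep it small.

```python
import torch

lr, momentum, wd = 0.001, 0.78, 0.1

# One parameter tensor, optimized by torch.optim.SGD
w_torch = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)
opt = torch.optim.SGD([w_torch], lr=lr, momentum=momentum, weight_decay=wd)

# The same parameter, updated by hand
w_manual = w_torch.detach().clone()
buf = torch.zeros_like(w_manual)  # momentum buffer (velocity)

for _ in range(5):
    # Toy loss 0.5 * ||w||^2, gradient = w
    opt.zero_grad()
    (0.5 * (w_torch ** 2).sum()).backward()
    opt.step()

    g = w_manual + wd * w_manual     # gradient plus weight-decay term
    buf = momentum * buf + g         # accumulate velocity
    w_manual = w_manual - lr * buf   # step along the velocity

print(torch.allclose(w_torch.detach(), w_manual, atol=1e-6))
```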
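2.2 Learning-rate tuning
- torch.optim.lr_scheduler.ReduceLROnPlateau(): lowers the learning rate when a monitored metric (e.g. the loss) stops improving
- torch.optim.lr_scheduler.StepLR(): multiplies the learning rate by gamma every step_size epochs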
# net and train(...) stand for your model and one epoch of training (placeholders)
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode="min", factor=0.1, patience=10)
for epoch in range(1000):
    loss_val = train(...)
    lr_scheduler.step(loss_val)  # pass in the monitored loss
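A minimal sketch of how `ReduceLROnPlateau` reacts, assuming a dummy parameter and a loss that has stopped improving entirely. The first `step()` call records the best value; every later call that does not beat it (by more than the default relative `threshold` of 1e-4) counts as a bad epoch, and once the count exceeds `patience` the learning rate is multiplied by `factor` and the count resets. With `patience=10`, the lr should stay at 1.0 for the first 11 calls and drop to 0.1 on the 12th.

```python
import torch

w = torch.zeros(1, requires_grad=True)  # dummy parameter (assumption)
opt = torch.optim.SGD([w], lr=1.0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="min", factor=0.1, patience=10)

lrs = []
for epoch in range(15):
    val_loss = 1.0  # assumed flat loss: no improvement after the first epoch
    scheduler.step(val_loss)
    lrs.append(opt.param_groups[0]["lr"])

print(lrs)
```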
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=30, gamma=0.1)
for epoch in range(1000):
    train(...)
    lr_scheduler.step()  # no metric needed; call once per epoch, after training
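With one `scheduler.step()` per epoch, `StepLR` gives epoch `e` the learning rate `lr0 * gamma ** (e // step_size)`. Since PyTorch 1.1 the scheduler is meant to be stepped after `optimizer.step()`, i.e. at the end of each epoch, which is the order used below. The dummy parameter is an assumption, and the bare `opt.step()` only stands in for a real epoch of training.

```python
import torch

w = torch.zeros(1, requires_grad=True)  # dummy parameter (assumption)
opt = torch.optim.SGD([w], lr=1.0)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)

lrs = []
for epoch in range(90):
    lrs.append(opt.param_groups[0]["lr"])  # lr used during this epoch
    opt.step()        # stands in for one epoch of training
    scheduler.step()  # decay at the end of the epoch
```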
3. Early Stopping
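Early stopping tracks a validation metric and halts training once it has not improved for a set number of epochs (`patience`). Core PyTorch does not ship an early-stopping class, so the `EarlyStopping` helper below is a hypothetical sketch in plain Python; `patience` and `min_delta` are assumed knobs. In a real loop you would typically also save a checkpoint with `torch.save(model.state_dict(), path)` whenever the loss improves and reload it after stopping.

```python
class EarlyStopping:
    """Signal a stop when the validation loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss      # improvement: remember it, reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1      # no improvement this epoch
        return self.bad_epochs >= self.patience


# Usage with a made-up validation-loss curve whose minimum is at epoch 3
stopper = EarlyStopping(patience=3)
val_losses = [1.0, 0.8, 0.7, 0.65, 0.66, 0.67, 0.68, 0.69, 0.70]
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.step(val_loss):
        stopped_at = epoch
        break
```

Here the loss stops improving after epoch 3, so three bad epochs later (epoch 6) the helper signals a stop, with `stopper.best` holding 0.65.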
4. Dropout
model = nn.Sequential(
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),  # zero each activation with probability 0.5 during training
    nn.ReLU(),
)
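`nn.Dropout(p)` zeroes each element with probability `p` in training mode and scales the survivors by `1 / (1 - p)`, so the expected activation is unchanged; in evaluation mode it passes inputs through untouched. This is why `model.train()` and `model.eval()` must be switched around training and evaluation. The sketch below shows both modes on an all-ones tensor, an assumed toy input.

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(10000)

drop.train()   # training mode: randomly zero elements, rescale the rest by 1 / (1 - p)
y_train = drop(x)
frac_dropped = (y_train == 0).float().mean().item()  # should be close to p

drop.eval()    # evaluation mode: dropout is the identity
y_eval = drop(x)

print(sorted(set(y_train.tolist())), frac_dropped)
```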
by CyrusMay 2022 07 03