PyTorch (6): Model Tuning Tricks
2022-07-07 04:44:00 【CyrusMay】
1. Regularization
1.1 L1 Regularization
import torch
import torch.nn.functional as F
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MLP = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(inplace=True),
    nn.Linear(64, 32),
    nn.ReLU(inplace=True),
    nn.Linear(32, 10),
)
MLP.to(device)

loss_classify = nn.CrossEntropyLoss().to(device)

# L1 norm: sum of the absolute values of all parameters
l1_loss = 0
for param in MLP.parameters():
    l1_loss += torch.sum(torch.abs(param))

# dummy batch so the example runs end to end
x = torch.randn(16, 128, device=device)
y = torch.randint(0, 10, (16,), device=device)

# add the weighted L1 penalty to the classification loss
lambda_l1 = 0.01
loss = loss_classify(MLP(x), y) + lambda_l1 * l1_loss
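Note that the optimizers' weight_decay argument only implements an L2 penalty, so an L1 penalty has to be added to the loss by hand as above; the coefficient lambda_l1 controls how strongly sparsity is encouraged.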
1.2 L2 Regularization
import torch
import torch.nn.functional as F
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

MLP = nn.Sequential(
    nn.Linear(128, 64),
    nn.ReLU(inplace=True),
    nn.Linear(64, 32),
    nn.ReLU(inplace=True),
    nn.Linear(32, 10),
)
MLP.to(device)

# L2 norm: implemented through the optimizer's weight_decay argument
opt = torch.optim.SGD(MLP.parameters(), lr=0.001, weight_decay=0.1)
loss_classify = nn.CrossEntropyLoss().to(device)
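For plain SGD, weight_decay=0.1 adds 0.1·w to every parameter's gradient, which corresponds to a (0.1/2)·‖w‖² term in the loss. A minimal sketch of the same penalty written out by hand, mirroring the L1 example above (lambda_l2 here is just the weight_decay value, kept explicit for illustration):

lambda_l2 = 0.1
l2_loss = 0
for param in MLP.parameters():
    l2_loss += torch.sum(param ** 2)
# total loss = classification loss + 0.5 * lambda_l2 * l2_loss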
2. Momentum and Learning Rate Decay
2.1 Momentum
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.78, weight_decay=0.1)
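With momentum, SGD keeps a velocity buffer per parameter, v = momentum * v + grad, and the update becomes w = w - lr * v (PyTorch's formulation with the default dampening=0 and nesterov=False). This smooths the step direction across iterations and damps oscillations.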
2.2 Learning Rate Tuning
- torch.optim.lr_scheduler.ReduceLROnPlateau(): reduce the learning rate when the monitored loss stops decreasing
- torch.optim.lr_scheduler.StepLR(): decay the learning rate every fixed number of epochs
opt = torch.optim.SGD(net.parameters(), lr=1)
# mode="min": the monitored value should decrease; factor=0.1: multiply lr by 0.1;
# patience=10: wait 10 epochs without improvement before reducing
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode="min", factor=0.1, patience=10)
for epoch in range(1000):
    loss_val = train(...)
    lr_scheduler.step(loss_val)  # pass the monitored loss to the scheduler
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=30, gamma=0.1)  # multiply lr by 0.1 every 30 epochs
for epoch in range(1000):
    train(...)
    lr_scheduler.step()  # no loss needed; decay on a fixed schedule after each epoch
3. Early Stopping
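Early stopping ends training once the validation metric stops improving for a number of epochs, which limits overfitting without changing the loss. PyTorch has no built-in early stopping, so a minimal hand-rolled sketch looks like this (the patience value, the validate(...) helper, and the checkpoint filename are illustrative assumptions):

best_val_loss = float("inf")
patience = 10   # epochs to wait without improvement (assumed value)
wait = 0
for epoch in range(1000):
    train(...)                    # training step, as in the examples above
    val_loss = validate(...)      # hypothetical helper returning the validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        wait = 0
        torch.save(model.state_dict(), "best_model.pt")  # keep the best weights
    else:
        wait += 1
        if wait >= patience:
            break                 # validation loss has stopped improving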
4. Dropout
model = nn.Sequential(
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),   # randomly zeroes 50% of the activations during training
    nn.ReLU(),
)
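Dropout only takes effect in training mode; switch the model between modes explicitly, otherwise units will also be dropped at inference time:

model.train()   # dropout active during training
# ... training loop ...
model.eval()    # dropout disabled for validation / inference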
by CyrusMay 2022 07 03