PyTorch (6): Model Tuning Tricks
2022-07-07 04:44:00 【CyrusMay】
1. Regularization
1.1 L1 Regularization
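import torch
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    ).to(device)
loss_classify = nn.CrossEntropyLoss().to(device)

# Dummy batch for illustration: 16 samples, 128 features, 10 classes
x = torch.randn(16, 128, device=device)
y = torch.randint(0, 10, (16,), device=device)

# L1 norm: sum of absolute values of all parameters (weights and biases)
l1_lambda = 0.01  # regularization strength (illustrative value)
l1_loss = 0
for param in MLP.parameters():
    l1_loss += torch.sum(torch.abs(param))
loss = loss_classify(MLP(x), y) + l1_lambda * l1_loss
loss.backward()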
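To see why the L1 term pushes weights toward exactly zero, look at its gradient: the derivative of λ·|w| is λ·sign(w), a constant-magnitude pull toward zero no matter how small w already is. The sketch below checks this with autograd on one small tensor; `l1_lambda` and the tensor values are illustrative assumptions.

```python
import torch

# Illustrative values (assumptions, not from the original post)
l1_lambda = 0.01
w = torch.tensor([-2.0, -0.5, 0.3, 4.0], requires_grad=True)

# L1 penalty: lambda * sum(|w|)
penalty = l1_lambda * torch.sum(torch.abs(w))
penalty.backward()

# The gradient is lambda * sign(w): the same magnitude for every nonzero weight
expected = l1_lambda * torch.sign(w.detach())
print(w.grad)
print(torch.allclose(w.grad, expected))  # True
```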
1.2 L2 Regularization
import torch
from torch import nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    ).to(device)
# L2 norm: for SGD, weight_decay=0.1 adds 0.1 * w to each parameter's gradient,
# which is the gradient of the penalty (0.1 / 2) * ||w||^2
opt = torch.optim.SGD(MLP.parameters(), lr=0.001, weight_decay=0.1)  # L2 regularization via weight_decay
loss = nn.CrossEntropyLoss().to(device)
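For SGD, `weight_decay=wd` adds `wd * w` to each parameter's gradient before the update, which is exactly the gradient of an explicit penalty (wd/2)·‖w‖² in the loss. The sketch below checks this on a single linear layer: one step using `weight_decay` and one step with the penalty added to the loss by hand should reach the same parameters. The layer size, random data, and `wd` value are illustrative assumptions.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
wd, lr = 0.1, 0.001
model_a = nn.Linear(4, 3)
model_b = copy.deepcopy(model_a)  # identical starting weights
criterion = nn.CrossEntropyLoss()
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

# Variant A: L2 via the optimizer's weight_decay argument
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, weight_decay=wd)
opt_a.zero_grad()
criterion(model_a(x), y).backward()
opt_a.step()

# Variant B: explicit penalty (wd / 2) * sum(w^2) added to the loss
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr)
opt_b.zero_grad()
l2 = sum(torch.sum(p ** 2) for p in model_b.parameters())
(criterion(model_b(x), y) + wd / 2 * l2).backward()
opt_b.step()

# Both variants should end up with (numerically) identical parameters
same = all(torch.allclose(pa, pb, atol=1e-6)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)
```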
2. Momentum and Learning Rate Decay
2.1 Momentum
opt = torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.78,weight_decay=0.1)
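With momentum μ, PyTorch's SGD keeps a velocity buffer per parameter: v ← μ·v + g, then w ← w − lr·v (with the default `dampening=0` and `nesterov=False`). The sketch below applies that rule by hand to a toy quadratic loss and compares the result with `torch.optim.SGD` after a few steps. The loss function and step count are illustrative assumptions, and `weight_decay` is omitted to keep the comparison simple.

```python
import torch

lr, mu, steps = 0.001, 0.78, 5


def loss_fn(w):
    # Toy quadratic loss (assumption for illustration): minimum at w = 3
    return torch.sum((w - 3.0) ** 2)


# Reference: torch.optim.SGD with momentum
w_ref = torch.tensor([0.0, 1.0], requires_grad=True)
opt = torch.optim.SGD([w_ref], lr=lr, momentum=mu)
for _ in range(steps):
    opt.zero_grad()
    loss_fn(w_ref).backward()
    opt.step()

# Manual momentum update: v <- mu * v + g ; w <- w - lr * v
w = torch.tensor([0.0, 1.0])
v = torch.zeros_like(w)
for _ in range(steps):
    g = 2.0 * (w - 3.0)  # analytic gradient of loss_fn
    v = mu * v + g
    w = w - lr * v

print(torch.allclose(w, w_ref.detach()))  # True
```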
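2.2 Learning Rate Tuning
- torch.optim.lr_scheduler.ReduceLROnPlateau(): lowers the learning rate when a monitored metric (e.g. the loss) stops decreasing
- torch.optim.lr_scheduler.StepLR(): multiplies the learning rate by a fixed factor every fixed number of epochs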
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode="min", factor=0.1, patience=10)
for epoch in range(1000):
    loss_val = train(...)        # train(...) is a placeholder for one epoch of training
    lr_scheduler.step(loss_val)  # monitor the loss; lr *= 0.1 after more than 10 epochs without improvement
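With `mode="min"`, `factor=0.1` and `patience=10`, `ReduceLROnPlateau` multiplies the learning rate by 0.1 only once the monitored value has failed to improve for more than 10 consecutive calls to `step()`. The sketch below feeds it a constant, never-improving loss in place of a real `train(...)` (an assumption made so the schedule is easy to read) and records the learning rate after each call; `net` is a placeholder model.

```python
import torch
from torch import nn

net = nn.Linear(2, 1)  # placeholder model (assumption)
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer=opt, mode="min", factor=0.1, patience=10)

lrs = []
for epoch in range(30):
    loss_val = 1.0  # stand-in for train(...): the loss never improves
    lr_scheduler.step(loss_val)
    lrs.append(opt.param_groups[0]["lr"])

# Call 1 sets the best value; calls 2-12 are 11 "bad" epochs, and the
# 11th bad epoch exceeds patience=10, so the lr drops on the 12th call.
print(lrs[:13])
```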

opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=30, gamma=0.1)
for epoch in range(1000):
    train(...)           # optimizer.step() is called inside training
    lr_scheduler.step()  # lr *= 0.1 every 30 epochs; the loss is not monitored
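`StepLR(step_size=30, gamma=0.1)` multiplies the learning rate by 0.1 every 30 calls to `lr_scheduler.step()`, independent of the loss. The sketch below records the learning rate used in each epoch; a bare `opt.step()` stands in for the hypothetical `train(...)` so the optimizer steps before the scheduler, the order PyTorch expects since version 1.1. `net` is a placeholder model.

```python
import torch
from torch import nn

net = nn.Linear(2, 1)  # placeholder model (assumption)
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=30, gamma=0.1)

lrs = []
for epoch in range(90):
    lrs.append(opt.param_groups[0]["lr"])  # lr used during this epoch
    opt.step()           # stand-in for train(...); no gradients, so no update
    lr_scheduler.step()  # advance the schedule once per epoch

# Epochs 0-29 use lr=1, epochs 30-59 use 0.1, epochs 60-89 use 0.01
print(lrs[0], lrs[30], lrs[60])
```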
3. Early Stopping
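Early stopping monitors a validation metric and halts training once it has stopped improving for a set number of epochs, usually keeping the best checkpoint. PyTorch itself ships no early-stopping class, so the sketch below is a hypothetical minimal helper (`EarlyStopping`, `patience` and `min_delta` are names chosen here), driven by a made-up sequence of validation losses rather than real training.

```python
class EarlyStopping:
    """Hypothetical helper: stop when val loss hasn't improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: remember it and reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1  # no improvement this epoch
        return self.bad_epochs >= self.patience


# Made-up validation losses: improve for 4 epochs, then plateau
val_losses = [1.0, 0.8, 0.6, 0.55, 0.56, 0.57, 0.55, 0.58, 0.6]
stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    # In real training: train one epoch, evaluate, save a checkpoint on improvement
    if stopper.step(val_loss):
        stopped_at = epoch
        break
print(stopped_at, stopper.best)  # 6 0.55
```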
4. Dropout
model = nn.Sequential(
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),  # randomly zero 50% of activations during training
    nn.ReLU(),
)
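`nn.Dropout` only acts in training mode: after `model.train()` it zeroes each element with probability `p` and scales the survivors by 1/(1 - p) so the expected activation is unchanged, and after `model.eval()` it passes inputs through untouched. The sketch below shows both modes on a tensor of ones; the input size and seed are illustrative.

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(1000)

drop.train()  # training mode (the default for newly created modules)
y_train = drop(x)
# Every element is either dropped (0) or scaled by 1 / (1 - 0.5) = 2
print(torch.unique(y_train))

drop.eval()  # evaluation mode: dropout becomes the identity
y_eval = drop(x)
print(torch.equal(y_eval, x))  # True
```

Remember to call `model.eval()` before validation or inference; otherwise dropout keeps discarding activations.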
by CyrusMay 2022 07 03