PyTorch (6): Model Tuning Tricks
2022-07-07 04:44:00, CyrusMay
1. Regularization
1.1 L1 Regularization
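import torch
from torch import nn

# Use the first GPU if available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    )
MLP.to(device)
loss_classify = nn.CrossEntropyLoss().to(device)
# Dummy batch for illustration: 16 samples with 128 features, 10 classes
x = torch.randn(16, 128, device=device)
y = torch.randint(0, 10, (16,), device=device)
# L1 norm: sum of |w| over all parameters
l1_loss = 0
for param in MLP.parameters():
    l1_loss += torch.sum(torch.abs(param))
l1_lambda = 0.01  # regularization strength (illustrative value)
# Add the L1 term to the computed classification loss, not to the loss module itself
loss = loss_classify(MLP(x), y) + l1_lambda * l1_loss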
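For reference, here is a minimal sketch of one complete CPU training step with an L1 penalty. The helper `l1_penalty`, the coefficient `l1_lambda`, the smaller model and the random batch are all hypothetical, chosen only for illustration. Because the gradient of |w| is sign(w), the penalty pulls every weight toward zero by a constant amount per step, which is why L1 tends to produce sparse weights; the last lines check that gradient directly.

```python
import torch
from torch import nn


def l1_penalty(model: nn.Module) -> torch.Tensor:
    """Hypothetical helper: sum of |w| over all trainable parameters (weights and biases)."""
    return sum(p.abs().sum() for p in model.parameters() if p.requires_grad)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
criterion = nn.CrossEntropyLoss()
opt = torch.optim.SGD(model.parameters(), lr=0.01)
l1_lambda = 1e-4  # hypothetical regularization strength

x = torch.randn(32, 128)         # dummy batch: 32 samples, 128 features
y = torch.randint(0, 10, (32,))  # dummy labels for 10 classes

# One training step: classification loss plus the scaled L1 term
opt.zero_grad()
loss = criterion(model(x), y) + l1_lambda * l1_penalty(model)
loss.backward()
opt.step()

# On its own, the penalty's gradient is sign(w)
probe = nn.Linear(4, 2)
l1_penalty(probe).backward()
print(torch.equal(probe.weight.grad, torch.sign(probe.weight.detach())))  # True
```

Scaling by `l1_lambda` keeps the penalty from swamping the classification loss; its value has to be tuned per task.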
1.2 L2 Regularization
import torch
from torch import nn

# Use the first GPU if available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    )
MLP.to(device)
# L2 norm: with SGD, weight_decay=wd adds wd * w to every gradient,
# which is the gradient of the penalty (wd / 2) * ||w||^2
opt = torch.optim.SGD(MLP.parameters(), lr=0.001, weight_decay=0.1)
loss = nn.CrossEntropyLoss().to(device)
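To see why `weight_decay` counts as L2 regularization, the sketch below takes one SGD step in two ways on a hypothetical small `nn.Linear` model with dummy data: (A) with `weight_decay=wd`, and (B) with no weight decay but the explicit penalty `(wd / 2) * ||w||^2` added to the loss. For plain SGD the resulting parameters match, because the penalty's gradient is exactly `wd * w`. This equivalence holds for optimizers that fold the decay into the gradient, like SGD; `torch.optim.AdamW`, for example, decouples weight decay from the gradient instead.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
base = nn.Linear(8, 3)         # hypothetical small model
x = torch.randn(5, 8)          # dummy inputs
y = torch.randint(0, 3, (5,))  # dummy labels
criterion = nn.CrossEntropyLoss()
lr, wd = 0.001, 0.1

# (A) weight decay handled by the optimizer
model_a = copy.deepcopy(base)
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, weight_decay=wd)
opt_a.zero_grad()
criterion(model_a(x), y).backward()
opt_a.step()

# (B) no weight decay, explicit penalty (wd / 2) * ||w||^2 added to the loss
model_b = copy.deepcopy(base)
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr)
opt_b.zero_grad()
l2 = sum((p ** 2).sum() for p in model_b.parameters())
(criterion(model_b(x), y) + 0.5 * wd * l2).backward()
opt_b.step()

same = all(torch.allclose(pa, pb, atol=1e-6)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```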
2. Momentum and Learning-Rate Decay
2.1 Momentum
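# model: your network; momentum=0.78 scales the previous update (velocity) by 0.78 before adding the new gradient
opt = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.78, weight_decay=0.1)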
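With the default `dampening=0` and `nesterov=False`, PyTorch's SGD keeps a velocity buffer updated as v <- momentum * v + g and moves the parameter by p <- p - lr * v, where v starts out equal to the first gradient. The sketch below replays three steps on the toy objective f(p) = p^2 (an assumption chosen for illustration) and compares a hand-written version of that rule with `torch.optim.SGD`.

```python
import torch

lr, mu = 0.1, 0.78
p = torch.tensor([1.0], requires_grad=True)
opt = torch.optim.SGD([p], lr=lr, momentum=mu)

p_manual, v = 1.0, None  # hand-rolled copy of the same update
for _ in range(3):
    # PyTorch step on f(p) = p^2
    opt.zero_grad()
    (p ** 2).sum().backward()
    opt.step()

    # Manual step: g = f'(p) = 2p; v <- mu * v + g; p <- p - lr * v
    g = 2 * p_manual
    v = g if v is None else mu * v + g
    p_manual -= lr * v

print(p.item(), p_manual)  # both approximately 0.1407
```

Because the buffer accumulates past gradients, the effective step grows along directions where gradients keep the same sign, which is what speeds up progress on long, shallow slopes.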
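2.2 Learning-Rate Scheduling
- torch.optim.lr_scheduler.ReduceLROnPlateau(): lowers the learning rate when a monitored metric, such as the loss, stops improving
- torch.optim.lr_scheduler.StepLR(): multiplies the learning rate by gamma every step_size epochs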
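# net: your model; train(...): your own training function, returning this epoch's loss
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode="min", factor=0.1, patience=10)
for epoch in range(1000):
    loss_val = train(...)
    lr_scheduler.step(loss_val)  # monitor the loss: lr *= 0.1 once it has not improved for more than 10 epochs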
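Since `train(...)` above is a placeholder, here is a self-contained sketch that feeds `ReduceLROnPlateau` a made-up loss curve (it improves for five epochs, then stays flat at 0.2) and records the learning rate after each call. The model and the loss values are hypothetical. With `patience=10`, the rate is multiplied by `factor=0.1` once the loss has gone more than 10 consecutive epochs without improving, and the counter then starts over.

```python
import torch
from torch import nn

model = nn.Linear(4, 1)  # hypothetical model; only its optimizer matters here
opt = torch.optim.SGD(model.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="min", factor=0.1, patience=10)

# Made-up losses: improve during epochs 0-4, then plateau at 0.2
fake_losses = [1.0 / (epoch + 1) if epoch < 5 else 0.2 for epoch in range(30)]

lrs = []
for loss_val in fake_losses:
    scheduler.step(loss_val)               # monitor the loss
    lrs.append(opt.param_groups[0]["lr"])  # lr after this epoch's check

print(lrs[14], lrs[15])  # 1.0 0.1
```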
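opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=30, gamma=0.1)
for epoch in range(1000):
    train(...)           # one epoch of training, including opt.step()
    lr_scheduler.step()  # called after the optimizer step; lr *= 0.1 every 30 epochs (no loss is monitored)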
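A self-contained sketch of the StepLR schedule, with a dummy regression step on random data standing in for `train(...)` (model and data are hypothetical; the initial rate is 0.1 here rather than 1 to keep the dummy step stable). It records the learning rate used in each epoch, which follows `lr * gamma ** (epoch // step_size)`. Calling `scheduler.step()` after `opt.step()` is the order PyTorch expects since version 1.1; reversing it triggers a warning and skips the first value of the schedule.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)  # hypothetical model
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1)

lrs = []
for epoch in range(90):
    lrs.append(opt.param_groups[0]["lr"])  # lr used during this epoch
    # Dummy training step on random data (stands in for train(...))
    x, y = torch.randn(8, 4), torch.randn(8, 1)
    loss = nn.functional.mse_loss(model(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()
    scheduler.step()  # after opt.step(): lr *= 0.1 every 30 epochs

print(lrs[29], lrs[30], lrs[60])  # approximately 0.1, 0.01, 0.001
```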
3. Early Stopping
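Early stopping monitors a validation metric and ends training once it has not improved for a set number of epochs (the patience), usually keeping the best weights seen so far. Below is a minimal sketch; the class `EarlyStopping`, its parameters and the validation-loss curve are hypothetical and not part of PyTorch.

```python
import copy


class EarlyStopping:
    """Hypothetical helper: stop when val loss has not improved for `patience` epochs."""

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta  # minimum decrease that counts as an improvement
        self.best = float("inf")
        self.best_epoch = -1
        self.best_state = None
        self.bad_epochs = 0

    def step(self, val_loss, epoch, model=None):
        """Record this epoch's validation loss; return True when training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.best_epoch = epoch
            self.bad_epochs = 0
            if model is not None:
                # Keep a copy of the best weights so they can be restored later
                self.best_state = copy.deepcopy(model.state_dict())
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Usage with a made-up validation-loss curve (no real model involved)
val_losses = [0.90, 0.70, 0.60, 0.58, 0.59, 0.60, 0.61, 0.62, 0.63]
stopper = EarlyStopping(patience=3)
for epoch, vl in enumerate(val_losses):
    if stopper.step(vl, epoch):
        print(f"stop at epoch {epoch}, best epoch {stopper.best_epoch}")
        break
```

In a real loop you would pass the model to `step`, then call `model.load_state_dict(stopper.best_state)` after stopping to roll back to the best epoch.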
4. Dropout
from torch import nn

model = nn.Sequential(
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),  # zeroes each activation with probability 0.5 during training
    nn.ReLU(),
)
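`nn.Dropout` only acts in training mode: it zeroes each element with probability `p` and scales the survivors by `1 / (1 - p)`, so the expected activation is unchanged; in evaluation mode it passes inputs through untouched. That is why a model containing dropout must be switched with `model.train()` and `model.eval()` around validation. A quick check on a tensor of ones (the tensor size is arbitrary):

```python
import torch
from torch import nn

torch.manual_seed(0)
drop = nn.Dropout(p=0.5)
x = torch.ones(10000)

drop.train()  # training mode: zero about half the entries, scale the rest by 1 / (1 - 0.5)
y_train = drop(x)
drop.eval()   # evaluation mode: dropout is the identity
y_eval = drop(x)

print(y_train.unique())                      # tensor([0., 2.])
frac_zero = (y_train == 0).float().mean().item()
print(frac_zero)                             # roughly 0.5
print(torch.equal(y_eval, x))                # True
```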
by CyrusMay, 2022-07-03