PyTorch (6): Model Tuning Tricks
2022-07-07 04:44:00 【CyrusMay】
1. Regularization
1.1 L1 regularization
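import torch
from torch import nn

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    )
MLP.to(device)
loss_classify = nn.CrossEntropyLoss().to(device)

# Dummy batch for illustration only (not part of the original post)
x = torch.randn(16, 128, device=device)
target = torch.randint(0, 10, (16,), device=device)

# L1 norm of all parameters
l1_lambda = 1e-4  # penalty strength; illustrative value, tune per task
l1_loss = 0
for param in MLP.parameters():
    l1_loss += torch.sum(torch.abs(param))
# loss_classify is a module, so call it on logits and targets to get a loss value
loss = loss_classify(MLP(x), target) + l1_lambda * l1_loss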
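To see why an L1 penalty pushes weights toward zero, it can help to look at its gradient. The sketch below is a small illustration, assuming a hypothetical weight vector `w` and penalty strength `l1_lambda` (neither is from the original post): the gradient of the penalty has the same magnitude for every nonzero weight, regardless of how large the weight is.

```python
import torch

# Hypothetical weights and penalty strength, chosen only for illustration
w = torch.tensor([0.5, -1.5, 2.0], requires_grad=True)
l1_lambda = 0.01

# L1 penalty: lambda * sum(|w_i|)
penalty = l1_lambda * w.abs().sum()
penalty.backward()

# The gradient is lambda * sign(w): a constant-size pull toward zero
expected_grad = l1_lambda * torch.sign(w.detach())
grads_match = torch.allclose(w.grad, expected_grad)
print(w.grad, grads_match)
```

Because the pull does not shrink as a weight gets small, small weights tend to be driven all the way to zero, which is the usual explanation for the sparsity L1 produces.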
1.2 L2 regularization
import torch
from torch import nn

# Use the GPU when available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
MLP = nn.Sequential(nn.Linear(128, 64),
                    nn.ReLU(inplace=True),
                    nn.Linear(64, 32),
                    nn.ReLU(inplace=True),
                    nn.Linear(32, 10)
                    )
MLP.to(device)
# L2 norm
opt = torch.optim.SGD(MLP.parameters(), lr=0.001, weight_decay=0.1)  # weight_decay implements the L2 penalty
loss_classify = nn.CrossEntropyLoss().to(device)  # loss module; call it on logits and targets
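For plain SGD, `weight_decay=wd` adds `wd * p` to each parameter's gradient, which matches adding `0.5 * wd * sum(p**2)` to the loss by hand. The sketch below checks this on a tiny model; the model size, data, and the names `model_a` and `model_b` are assumptions made for the illustration.

```python
import copy
import torch
from torch import nn

torch.manual_seed(0)
model_a = nn.Linear(4, 2)            # hypothetical tiny model
model_b = copy.deepcopy(model_a)     # identical starting weights
x = torch.randn(8, 4)
y = torch.randint(0, 2, (8,))
criterion = nn.CrossEntropyLoss()
lr, wd = 0.001, 0.1

# Variant A: let the optimizer apply the L2 penalty via weight_decay
opt_a = torch.optim.SGD(model_a.parameters(), lr=lr, weight_decay=wd)
opt_a.zero_grad()
criterion(model_a(x), y).backward()
opt_a.step()

# Variant B: add the L2 penalty to the loss explicitly
opt_b = torch.optim.SGD(model_b.parameters(), lr=lr)
opt_b.zero_grad()
l2_penalty = sum(p.pow(2).sum() for p in model_b.parameters())
(criterion(model_b(x), y) + 0.5 * wd * l2_penalty).backward()
opt_b.step()

# Both variants should land on the same parameters after one step
same = all(
    torch.allclose(pa, pb, atol=1e-6)
    for pa, pb in zip(model_a.parameters(), model_b.parameters())
)
print(same)
```

Note that this equivalence is specific to plain SGD; adaptive optimizers such as Adam rescale the gradient, so there the two variants generally differ.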
2. Momentum and learning rate decay
2.1 Momentum
opt = torch.optim.SGD(model.parameters(),lr=0.001,momentum=0.78,weight_decay=0.1)
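With the default `dampening=0` and `nesterov=False`, SGD with momentum keeps a velocity buffer per parameter: `buf = momentum * buf + grad`, then `p = p - lr * buf`. The sketch below replays that rule by hand next to `torch.optim.SGD` on a toy quadratic loss; the loss, the starting point, and the hyperparameter values are assumptions for illustration, and weight decay is left out to keep the update rule visible.

```python
import torch

lr, mu = 0.1, 0.9  # illustrative values, not the ones from the post

# Parameter updated by torch.optim.SGD
w = torch.tensor([1.0, -2.0], requires_grad=True)
opt = torch.optim.SGD([w], lr=lr, momentum=mu)

# Hand-rolled copy of the same parameter and its velocity buffer
w_manual = w.detach().clone()
buf = torch.zeros_like(w_manual)

for step in range(3):
    # Optimizer path: loss = sum(w^2), so grad = 2 * w
    opt.zero_grad()
    (w ** 2).sum().backward()
    opt.step()

    # Manual path: accumulate velocity, then step along it
    grad = 2 * w_manual
    buf = mu * buf + grad
    w_manual = w_manual - lr * buf

matches = torch.allclose(w.detach(), w_manual)
print(w.detach(), w_manual, matches)
```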
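2.2 Learning rate tuning
- torch.optim.lr_scheduler.ReduceLROnPlateau(): lowers the learning rate when the monitored loss stops decreasing
- torch.optim.lr_scheduler.StepLR(): lowers the learning rate by a fixed factor every fixed number of epochs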
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=opt, mode="min", factor=0.1, patience=10)
for epoch in range(1000):
    loss_val = train(...)
    lr_scheduler.step(loss_val)  # monitor the loss
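To watch `ReduceLROnPlateau` react, the sketch below feeds it a synthetic loss curve that improves for three epochs and then stalls. The model `net`, the loss values, and `patience=2` are assumptions chosen to keep the run short; the learning rate is read back from `opt.param_groups`.

```python
import torch
from torch import nn

net = nn.Linear(4, 1)  # hypothetical stand-in model
opt = torch.optim.SGD(net.parameters(), lr=1.0)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode="min", factor=0.1, patience=2
)

# Synthetic validation losses: improve for three epochs, then plateau
fake_losses = [1.0, 0.8, 0.6] + [0.6] * 7

lrs = []
for loss_val in fake_losses:
    scheduler.step(loss_val)                  # report the monitored metric
    lrs.append(opt.param_groups[0]["lr"])     # learning rate after this epoch

print(lrs)
```

The learning rate stays at its initial value while the loss improves and is only multiplied by `factor` once the loss has failed to improve for more than `patience` consecutive epochs.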
opt = torch.optim.SGD(net.parameters(), lr=1)
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer=opt, step_size=30, gamma=0.1)
for epoch in range(1000):
    train(...)
    lr_scheduler.step()  # advance the schedule once per epoch, after the optimizer updates
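The `StepLR` schedule is easy to inspect by recording the learning rate at the start of each epoch. The sketch below uses a hypothetical linear model, random data, and a short schedule (`lr=0.1`, `step_size=3`) so the decay is visible within a few epochs; these values are assumptions, not the ones used in the post.

```python
import torch
from torch import nn

torch.manual_seed(0)
net = nn.Linear(4, 1)  # hypothetical stand-in model
opt = torch.optim.SGD(net.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=3, gamma=0.1)

x, y = torch.randn(8, 4), torch.randn(8, 1)

lrs = []
for epoch in range(7):
    lrs.append(opt.param_groups[0]["lr"])  # learning rate used in this epoch
    opt.zero_grad()
    nn.functional.mse_loss(net(x), y).backward()
    opt.step()
    scheduler.step()  # call after opt.step(), once per epoch

# Approximately 0.1 for epochs 0-2, 0.01 for epochs 3-5, 0.001 for epoch 6
print(lrs)
```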
3. Early Stopping
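Early stopping halts training once the validation loss has stopped improving for a given number of epochs. Below is a minimal sketch under stated assumptions: the `EarlyStopper` class, its `patience` and `min_delta` parameters, and the validation losses are all hypothetical and written for this illustration, not taken from PyTorch or from the original post.

```python
class EarlyStopper:
    """Hypothetical helper: signals a stop after `patience` epochs without improvement."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def should_stop(self, val_loss):
        if val_loss < self.best - self.min_delta:
            # Improvement: remember it and reset the counter
            self.best = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


# Made-up validation losses: the best value is reached at epoch 2
val_losses = [1.0, 0.8, 0.7, 0.72, 0.71, 0.73, 0.69]

stopper = EarlyStopper(patience=3)
stopped_at = None
for epoch, val_loss in enumerate(val_losses):
    if stopper.should_stop(val_loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best)
```

In a real training loop one would typically also save the model's `state_dict()` whenever a new best loss is seen, so the best checkpoint can be restored after stopping.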
4. Dropout
model = nn.Sequential(
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),
    nn.ReLU(),
)
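Dropout behaves differently in training and evaluation: in training mode `nn.Dropout` zeroes each element with probability `p` and scales the rest by `1/(1-p)`, while in evaluation mode it passes inputs through unchanged. The sketch below switches between the two with `model.train()` and `model.eval()`; the input batch is random data made up for the illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Sequential(
    nn.Linear(256, 128),
    nn.Dropout(p=0.5),
    nn.ReLU(),
)
x = torch.randn(4, 256)  # made-up input batch

model.train()            # dropout active: random units are zeroed
out_train = model(x)

model.eval()             # dropout disabled: output is deterministic
with torch.no_grad():
    out_eval_1 = model(x)
    out_eval_2 = model(x)

zeros_train = int((out_train == 0).sum())
zeros_eval = int((out_eval_1 == 0).sum())
print(zeros_train, zeros_eval, torch.equal(out_eval_1, out_eval_2))
```

Forgetting to call `model.eval()` before validation or inference is a common source of noisy, lower-than-expected test accuracy.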
by CyrusMay 2022 07 03