Custom Learning Rate
2022-07-27 18:27:00 【jstzwjr】
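Warmup + a custom learning-rate schedule (e.g. cosine). The learning rate first ramps up linearly over `warmup_steps` iterations, then decays from `base_lr` toward `eta_min`. The shape of the decay comes from `func()`, which is linear in `CustomScheduler` and is overridden by `PolyScheduler` and `CosineScheduler`.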
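Written as formulas (the notation is mine, read off the code): let $t$ be the iteration counter (`last_epoch`), $w$ = `warmup_steps`, $T$ = `max_steps`, $\eta_{\text{base}}$ = `base_lr`, $\eta_{\min}$ = `eta_min`. After the linear warmup, for $w \le t \le T$ the scheduler sets:

$$
p = \frac{t - w}{T - w}, \qquad
\eta_t = \eta_{\min} + (\eta_{\text{base}} - \eta_{\min})\, f(p)
$$

$$
f_{\text{linear}}(p) = 1 - p, \qquad
f_{\text{poly}}(p) = (1 - p)^{2}, \qquad
f_{\text{cos}}(p) = \cos\!\left(\tfrac{\pi}{2}\, p\right)
$$

All three factors go from 1 at $p = 0$ to 0 at $p = 1$, so the LR moves from $\eta_{\text{base}}$ to $\eta_{\min}$; the exponent 2 is the hard-coded `self.power`.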
```python
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Linear warmup, then a decay phase whose shape is given by func().

    step() is meant to be called once per iteration. func() returns a factor
    in [0, 1] that places the LR between eta_min and base_lr; this base class
    decays linearly.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001  # LR at the first warmup step
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2  # exponent used by PolyScheduler
        self.eta_min = eta_min
        # Let the base class manage last_epoch instead of overwriting it after
        # __init__. The base __init__ calls step() once, so the first iteration
        # already uses get_lr() at last_epoch == 0. To resume training, prefer
        # scheduler.load_state_dict().
        super().__init__(optimizer, last_epoch)

    def get_warmup_lr(self):
        # Linear ramp from warmup_lr_init (step 0) to base_lr (step warmup_steps),
        # so the very first step does not run with LR = 0.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        lr = self.warmup_lr_init + (self.base_lr - self.warmup_lr_init) * alpha
        return [lr for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        alpha = self.func()
        return [(self.base_lr - self.eta_min) * alpha + self.eta_min
                for _ in self.optimizer.param_groups]

    def _progress(self):
        # Fraction of the decay phase completed, clamped to [0, 1] so the LR
        # stays at eta_min instead of overshooting once max_steps is passed.
        t = float(self.last_epoch - self.warmup_steps)
        total = float(max(1, self.max_steps - self.warmup_steps))
        return min(max(t / total, 0.0), 1.0)

    def func(self):
        # Linear decay: 1 -> 0.
        return 1 - self._progress()


class PolyScheduler(CustomScheduler):
    def func(self):
        # Polynomial decay: (1 - progress) ** power, 1 -> 0.
        return pow(1 - self._progress(), self.power)


class CosineScheduler(CustomScheduler):
    def func(self):
        # Quarter-period cosine: cos(pi/2 * progress), 1 -> 0.
        return cos(pi / 2 * self._progress())
```
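To inspect the curve without building a model or optimizer, here is a framework-free sketch of the same warmup + decay schedule. `lr_at`, `kind` and `warmup_start` are hypothetical names used only for illustration. Clamping the decay progress to [0, 1] is an assumption of this sketch, so the LR holds at `eta_min` after `max_steps` (without clamping, the linear and cosine factors would turn negative).

```python
from math import cos, pi


def lr_at(step, base_lr, max_steps, warmup_steps, eta_min=0.0,
          kind="cosine", power=2, warmup_start=0.0):
    """Hypothetical helper: LR at a given iteration for warmup + decay.

    warmup_start and the clamping of progress are assumptions of this sketch.
    """
    if step < warmup_steps:
        # Linear warmup from warmup_start to base_lr.
        return warmup_start + (base_lr - warmup_start) * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    span = max(1, max_steps - warmup_steps)
    p = min(max((step - warmup_steps) / span, 0.0), 1.0)
    if kind == "linear":
        factor = 1 - p
    elif kind == "poly":
        factor = (1 - p) ** power
    elif kind == "cosine":
        factor = cos(pi / 2 * p)  # quarter-period cosine, as in CosineScheduler
    else:
        raise ValueError(f"unknown kind: {kind}")
    return eta_min + (base_lr - eta_min) * factor


if __name__ == "__main__":
    # Print a few points of a 100-iteration schedule with 10 warmup steps.
    for s in (0, 5, 10, 55, 100, 120):
        print(s, round(lr_at(s, base_lr=0.1, max_steps=100, warmup_steps=10), 6))
```

In a training loop the schedulers above appear to be designed to be stepped once per iteration (`optimizer.step()` followed by `scheduler.step()`); `scheduler.get_last_lr()` returns the current values for comparison with this sketch. Note that `CosineScheduler` uses the quarter-period factor cos(π/2·p), which decays fastest near the end, whereas PyTorch's built-in `CosineAnnealingLR` uses the half-period form (1 + cos(π·p))/2, which is flat at both ends.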