Custom learning rate
2022-07-27 21:09:00 【jstzwjr】
Warmup + custom learning-rate policy (e.g. cosine)
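After warmup, the three schedulers differ only in the decay factor they apply. The notation here is introduced for illustration and is not from the original post: s = `last_epoch`, W = `warmup_steps`, S = `max_steps`, η_max = `base_lr`, η_min = `eta_min`, p = `power` (fixed at 2). For W ≤ s ≤ S the code computes:

$$
t = \frac{s - W}{S - W}, \qquad
\eta(s) = \eta_{\min} + (\eta_{\max} - \eta_{\min})\, f(t)
$$

$$
f_{\text{linear}}(t) = 1 - t, \qquad
f_{\text{poly}}(t) = (1 - t)^{p}, \qquad
f_{\text{cos}}(t) = \cos\!\left(\tfrac{\pi}{2}\, t\right)
$$

Each f falls from 1 at t = 0 to 0 at t = 1, so the rate moves from η_max down to η_min; during the first W steps it ramps up linearly instead.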
```python
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Linear warmup, then a decay policy (this base class decays linearly)."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001  # learning rate at the first warmup step
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2  # exponent used by PolyScheduler
        self.eta_min = eta_min  # floor reached at max_steps
        # Initialize with last_epoch=-1 so the base class does not require
        # 'initial_lr' in every param group; the deprecated `verbose` argument is omitted.
        super().__init__(optimizer, last_epoch=-1)
        if last_epoch != -1:
            # Resuming: restore the step counter after the base-class init.
            self.last_epoch = last_epoch

    def _progress(self):
        """Fraction of the decay phase completed, clamped to [0, 1]."""
        t = float(self.last_epoch - self.warmup_steps) / float(self.max_steps - self.warmup_steps)
        # Clamping keeps the rate at eta_min past max_steps instead of going negative.
        return min(max(t, 0.0), 1.0)

    def get_warmup_lr(self):
        # Linear ramp from warmup_lr_init to base_lr over warmup_steps.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        lr = self.warmup_lr_init + (self.base_lr - self.warmup_lr_init) * alpha
        return [lr for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        alpha = self.func()
        return [(self.base_lr - self.eta_min) * alpha + self.eta_min for _ in self.optimizer.param_groups]

    def func(self):
        # Linear decay: 1 -> 0 over the decay phase.
        return 1 - self._progress()


class PolyScheduler(CustomScheduler):
    def func(self):
        # Polynomial decay: (1 - t) ** power.
        return pow(1 - self._progress(), self.power)


class CosineScheduler(CustomScheduler):
    def func(self):
        # Quarter-period cosine: cos(pi/2 * t) falls from 1 to 0.
        # (Standard cosine annealing uses 0.5 * (1 + cos(pi * t)) instead.)
        return cos(pi / 2 * self._progress())
```
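To inspect the curves without building an optimizer, the same warmup-then-decay formulas can be restated as a plain function. This is a minimal sketch, not the author's code: the function name `lr_at_step`, the `policy` strings and the `warmup_start` parameter (the rate at step 0, default 0.0) are hypothetical, and progress through the decay phase is clamped to [0, 1] so steps past `max_steps` stay at `eta_min`.

```python
from math import cos, pi


def lr_at_step(step, base_lr, max_steps, warmup_steps, eta_min=0.0,
               policy="cos", power=2, warmup_start=0.0):
    """Learning rate at `step`: linear warmup, then the chosen decay policy.

    `policy` is one of "linear", "poly", "cos" (names chosen for this sketch).
    """
    if step < warmup_steps:
        # Linear ramp from warmup_start to base_lr over warmup_steps.
        return warmup_start + (base_lr - warmup_start) * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    t = min(max((step - warmup_steps) / (max_steps - warmup_steps), 0.0), 1.0)
    if policy == "linear":
        alpha = 1 - t
    elif policy == "poly":
        alpha = (1 - t) ** power
    elif policy == "cos":
        alpha = cos(pi / 2 * t)  # quarter-period cosine, as in CosineScheduler
    else:
        raise ValueError(f"unknown policy: {policy!r}")
    return eta_min + (base_lr - eta_min) * alpha


if __name__ == "__main__":
    for policy in ("linear", "poly", "cos"):
        lrs = [lr_at_step(s, base_lr=0.1, max_steps=10, warmup_steps=2, policy=policy)
               for s in range(11)]
        print(policy, [round(lr, 4) for lr in lrs])
```

Printing the three lists shows the trade-off between policies: the polynomial curve drops fastest right after warmup, while the quarter cosine stays closest to `base_lr` the longest before falling to `eta_min`.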