Custom learning rate
2022-07-27 21:09:00 【jstzwjr】
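Warmup plus a custom learning-rate decay policy (e.g. cosine). `CustomScheduler` ramps the learning rate linearly from 0 to `base_lr` over the first `warmup_steps` steps, then decays it toward `eta_min` at `max_steps` (linearly by default). `PolyScheduler` and `CosineScheduler` change only the decay shape by overriding `func()`. All counts are in scheduler steps, i.e. calls to `scheduler.step()`.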
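Written out as a formula, the schedule implemented by the code is as follows. The notation is assumed here, not taken from the post: $t$ is `last_epoch`, $T_w$ is `warmup_steps`, $T$ is `max_steps`, $\eta_{\text{base}}$ is `base_lr`, and $\eta_{\min}$ is `eta_min`.

$$
\eta_t =
\begin{cases}
\eta_{\text{base}} \, \dfrac{t}{T_w}, & 0 \le t < T_w, \\[6pt]
\eta_{\min} + (\eta_{\text{base}} - \eta_{\min}) \, f(p_t), & T_w \le t \le T,
\end{cases}
\qquad
p_t = \frac{t - T_w}{T - T_w},
$$

$$
f(p) =
\begin{cases}
1 - p & \text{(CustomScheduler)}, \\
(1 - p)^2 & \text{(PolyScheduler, power} = 2\text{)}, \\
\cos\!\left(\dfrac{\pi p}{2}\right) & \text{(CosineScheduler)}.
\end{cases}
$$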
```python
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Linear warmup from 0 to base_lr, then linear decay to eta_min at max_steps.

    Subclasses only override func(), which maps decay progress to a factor in [0, 1].
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001  # returned only if get_lr() runs while last_epoch == -1
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2  # exponent used by PolyScheduler
        self.eta_min = eta_min
        # Always initialise the base class with last_epoch=-1: a different value would
        # require an "initial_lr" key in every param group. This runs step 0 once.
        super().__init__(optimizer, last_epoch=-1)
        if last_epoch != -1:
            # Resuming: jump to the saved step and apply its learning rate right away.
            # (Unconditionally resetting to -1 here would make step 0 run twice.)
            self.last_epoch = last_epoch
            for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                group["lr"] = lr
            self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def get_warmup_lr(self):
        # Linear warmup: 0 at step 0, base_lr at step warmup_steps.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        # Decay phase: interpolate between eta_min and base_lr by the factor from func().
        alpha = self.func()
        return [(self.base_lr - self.eta_min) * alpha + self.eta_min
                for _ in self.optimizer.param_groups]

    def progress(self):
        # Fraction of the decay phase completed, clamped to [0, 1] so the rate
        # settles at eta_min after max_steps instead of going negative or rising again.
        done = min(self.last_epoch, self.max_steps) - self.warmup_steps
        return float(done) / float(max(1, self.max_steps - self.warmup_steps))

    def func(self):
        # Linear decay factor: 1 -> 0.
        return 1.0 - self.progress()


class PolyScheduler(CustomScheduler):
    def func(self):
        # Polynomial decay factor: (1 - p) ** power.
        return pow(1.0 - self.progress(), self.power)


class CosineScheduler(CustomScheduler):
    def func(self):
        # Quarter-cosine decay factor: cos(pi / 2 * p), from 1 down to 0.
        return cos(pi / 2 * self.progress())
```
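To inspect or plot the schedule without building a model or optimizer, the same piecewise formula can be evaluated as a plain function. This is a minimal sketch, not part of the original post: `lr_at` and its `shape` argument are hypothetical names. It matches the three scheduler classes for steps up to `max_steps` and, as an assumption, clamps later steps so the rate stays at `eta_min`.

```python
from math import cos, pi


def lr_at(step, base_lr, max_steps, warmup_steps, eta_min=0.0, shape="linear", power=2):
    """Hypothetical helper: learning rate at `step` for warmup + decay schedules.

    shape="linear", "poly" or "cosine" corresponds to CustomScheduler,
    PolyScheduler and CosineScheduler respectively.
    """
    if step < warmup_steps:
        # Linear warmup from 0 to base_lr.
        return base_lr * step / warmup_steps
    # Decay progress in [0, 1]; steps past max_steps are clamped to the end.
    span = max(1, max_steps - warmup_steps)
    p = (min(step, max_steps) - warmup_steps) / span
    if shape == "linear":
        factor = 1.0 - p
    elif shape == "poly":
        factor = (1.0 - p) ** power
    elif shape == "cosine":
        factor = cos(pi / 2 * p)  # quarter cosine, as in CosineScheduler
    else:
        raise ValueError(f"unknown shape: {shape!r}")
    return eta_min + (base_lr - eta_min) * factor
```

For example, `[round(lr_at(s, 0.1, 100, 10, shape="cosine"), 4) for s in (0, 5, 10, 55, 100)]` gives `[0.0, 0.05, 0.1, 0.0707, 0.0]`. Note that this cosine variant uses a quarter period, $\cos(\pi p/2)$, rather than the half-cosine $\tfrac12(1+\cos \pi p)$ used by PyTorch's `CosineAnnealingLR`. Since $\cos(\pi p/2) \ge \cos^2(\pi p/2) = \tfrac12(1+\cos\pi p)$ on $[0, 1]$, it stays higher throughout the decay and reaches the floor with a nonzero slope instead of flattening out.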