Custom learning rate
2022-07-27 21:09:00 【jstzwjr】
Warm-up followed by a custom learning rate policy (such as cosine decay)
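Read off the code in this post, the schedule can be summarised as below. The symbols are notation introduced here for illustration, not names from the original: t stands for the step counter (`last_epoch`), W for `warmup_steps`, T for `max_steps`, p for `power`, and s for the fraction of the decay phase completed.

```latex
\[
\mathrm{lr}(t) =
\begin{cases}
\mathrm{base\_lr}\cdot \dfrac{t}{W}, & t < W,\\[6pt]
(\mathrm{base\_lr}-\eta_{\min})\,\alpha(s) + \eta_{\min}, & t \ge W,
\end{cases}
\qquad s = \frac{t-W}{T-W}
\]
\[
\alpha_{\text{linear}}(s) = 1-s,\qquad
\alpha_{\text{poly}}(s) = (1-s)^{p}\ (p=2),\qquad
\alpha_{\text{cos}}(s) = \cos\!\left(\tfrac{\pi}{2}\,s\right)
\]
```

Note that the cosine variant here spans a quarter period, so it differs from the half-cosine form (1 + cos(pi s)) / 2 used by PyTorch's `CosineAnnealingLR`.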
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Linear warm-up followed by a decay curve given by func() (linear here)."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2  # exponent used by PolyScheduler
        self.eta_min = eta_min
        # The attributes above must be set before this call, because the
        # base-class constructor performs an initial step() that calls get_lr().
        super().__init__(optimizer, last_epoch=-1)
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        # Linear ramp from 0 to base_lr over warmup_steps steps.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        else:
            # Decay phase: interpolate between base_lr and eta_min.
            alpha = self.func()
            return [(self.base_lr - self.eta_min) * alpha + self.eta_min for _ in self.optimizer.param_groups]

    def func(self):
        # Linear decay factor: 1 at the end of warm-up, 0 at max_steps.
        alpha = (
            1
            - float(self.last_epoch - self.warmup_steps)
            / float(self.max_steps - self.warmup_steps))
        return alpha


class PolyScheduler(CustomScheduler):
    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        super().__init__(optimizer, base_lr, max_steps, warmup_steps, eta_min, last_epoch)

    def func(self):
        # Polynomial decay factor: (1 - progress) ** power.
        alpha = pow(
            1
            - float(self.last_epoch - self.warmup_steps)
            / float(self.max_steps - self.warmup_steps),
            self.power,
        )
        return alpha


class CosineScheduler(CustomScheduler):
    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        super().__init__(optimizer, base_lr, max_steps, warmup_steps, eta_min, last_epoch)

    def func(self):
        # Quarter-period cosine decay factor: cos(pi / 2 * progress).
        alpha = cos(
            pi / 2
            * float(self.last_epoch - self.warmup_steps)
            / float(self.max_steps - self.warmup_steps))
        return alpha
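
To inspect the curves without building a model or optimizer, the same arithmetic can be reproduced in plain Python. This is a minimal sketch, not part of the original code: `schedule_lr` and its `policy` argument are hypothetical names introduced here, and `step` plays the role of `self.last_epoch` in the classes above.

```python
from math import cos, pi


def schedule_lr(step, base_lr, max_steps, warmup_steps, eta_min=0.0, policy="cos", power=2):
    """Return the learning rate at `step`, mirroring the scheduler classes.

    Hypothetical helper for inspection only; it does not touch an optimizer.
    """
    if step < warmup_steps:
        # Linear warm-up from 0 to base_lr.
        return base_lr * step / warmup_steps

    # Fraction of the decay phase completed (0 at end of warm-up, 1 at max_steps).
    progress = (step - warmup_steps) / (max_steps - warmup_steps)

    if policy == "linear":
        alpha = 1 - progress
    elif policy == "poly":
        alpha = (1 - progress) ** power
    elif policy == "cos":
        alpha = cos(pi / 2 * progress)
    else:
        raise ValueError(f"unknown policy: {policy}")

    # Interpolate between base_lr (alpha = 1) and eta_min (alpha = 0).
    return (base_lr - eta_min) * alpha + eta_min


if __name__ == "__main__":
    # Print a few points of each curve for a quick visual check.
    for policy in ("linear", "poly", "cos"):
        for step in (0, 5, 10, 55, 100):
            lr = schedule_lr(step, base_lr=0.1, max_steps=100, warmup_steps=10, policy=policy)
            print(policy, step, lr)
```

Like the classes, this sketch does not clamp `step` to `max_steps`, so values past `max_steps` fall outside the intended range; training loops are assumed to stop at or before that point.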