Custom Learning Rate
2022-07-27 18:27:00 【jstzwjr】
Warmup followed by a custom learning-rate decay (e.g. cosine), implemented as PyTorch `_LRScheduler` subclasses:
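Written out per step, the schedule that `get_lr()` computes for $0 \le t \le T$ is given below. Here $t$ is `last_epoch` (advanced by one on every `scheduler.step()` call), $W$ = `warmup_steps`, $T$ = `max_steps`, $\eta_{\text{base}}$ = `base_lr` and $\eta_{\min}$ = `eta_min`; the symbol names are introduced here for readability only.

$$
\eta(t)=\begin{cases}
\eta_{\text{base}}\,\dfrac{t}{W}, & t < W \\[6pt]
\eta_{\min} + (\eta_{\text{base}}-\eta_{\min})\, f(p), \quad p=\dfrac{t-W}{T-W}, & W \le t \le T
\end{cases}
$$

$$
f_{\text{linear}}(p)=1-p, \qquad f_{\text{poly}}(p)=(1-p)^{2}, \qquad f_{\cos}(p)=\cos\!\left(\frac{\pi}{2}\,p\right)
$$

All three decay factors go from 1 at $p=0$ to 0 at $p=1$, so the learning rate peaks at $\eta_{\text{base}}$ when warmup ends and reaches $\eta_{\min}$ at $t=T$. The cosine variant uses a quarter period; PyTorch's built-in `CosineAnnealingLR` instead uses the half-period form $\tfrac{1}{2}\left(1+\cos(\pi p)\right)$.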
```python
from math import cos, pi

from torch.optim.lr_scheduler import _LRScheduler


class CustomScheduler(_LRScheduler):
    """Linear warmup, then a decay defined by func() (linear decay in this base class)."""

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, eta_min=0, last_epoch=-1):
        # Attributes must exist before super().__init__(), which already calls get_lr().
        self.base_lr = base_lr              # peak lr, applied to every param group
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2                      # exponent used by PolyScheduler
        self.eta_min = eta_min              # lr floor reached at max_steps
        # The base class starts at last_epoch = -1 and immediately calls step(),
        # so get_lr() is first evaluated at last_epoch == 0.
        super().__init__(optimizer, last_epoch=-1)
        if last_epoch != -1:
            # Resuming: continue from the given step and write the matching lr
            # into the optimizer, instead of keeping the step-0 lr for one iteration.
            self.last_epoch = last_epoch
            for group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                group["lr"] = lr
            self._last_lr = [group["lr"] for group in self.optimizer.param_groups]

    def get_warmup_lr(self):
        # Linear ramp from 0 to base_lr over warmup_steps steps.
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        if self.last_epoch == -1:
            # Only hit if get_lr() runs before the base class's first step().
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()
        alpha = self.func()
        return [(self.base_lr - self.eta_min) * alpha + self.eta_min for _ in self.optimizer.param_groups]

    def progress(self):
        # Fraction of the decay phase completed, clamped to [0, 1] so the lr
        # stays at eta_min (instead of going negative or rising again) past max_steps.
        done = float(self.last_epoch - self.warmup_steps) / float(self.max_steps - self.warmup_steps)
        return min(max(done, 0.0), 1.0)

    def func(self):
        # Linear decay factor: 1 -> 0.
        return 1 - self.progress()


class PolyScheduler(CustomScheduler):
    def func(self):
        # Polynomial decay factor: (1 - progress) ** power.
        return pow(1 - self.progress(), self.power)


class CosineScheduler(CustomScheduler):
    def func(self):
        # Quarter-period cosine: cos(pi/2 * progress) falls from 1 to 0.
        return cos(pi / 2 * self.progress())
```
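To check the learning-rate curve before training (or to plot it), it can help to have the schedule as a plain function of the step. The sketch below restates the three schedules in pure Python under the same assumptions as the classes above (linear warmup from 0, decay progress clamped at `max_steps`); `lr_at`, `kind` and `power` are hypothetical names chosen for this illustration, not part of PyTorch or of the original code.

```python
from math import cos, pi


def lr_at(step, base_lr, max_steps, warmup_steps, eta_min=0.0, kind="cosine", power=2):
    """Illustrative helper: lr at `step` for linear warmup + linear/poly/cosine decay."""
    if step < warmup_steps:
        # Linear warmup from 0 up to base_lr.
        return base_lr * step / warmup_steps
    # Fraction of the decay phase completed, clamped to [0, 1].
    p = min(max((step - warmup_steps) / (max_steps - warmup_steps), 0.0), 1.0)
    if kind == "linear":
        alpha = 1 - p
    elif kind == "poly":
        alpha = (1 - p) ** power
    elif kind == "cosine":
        alpha = cos(pi / 2 * p)  # quarter-period cosine, 1 -> 0
    else:
        raise ValueError(f"unknown kind: {kind!r}")
    return eta_min + (base_lr - eta_min) * alpha


if __name__ == "__main__":
    # Coarse view of the cosine curve: 100 steps, the first 10 of them warmup.
    for step in range(0, 101, 10):
        print(step, round(lr_at(step, 0.1, 100, 10), 5))
```

One way to train with this function instead of a custom class is `torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: lr_at(step, base_lr, max_steps, warmup_steps) / base_lr)`, with the optimizer created at `lr=base_lr`, since `LambdaLR` multiplies each param group's initial lr by the factor the lambda returns.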