Beware of Your Dictionaries and Boilerplate Code
2022-07-30 21:24:00 【InfoQ】
Mismatched dictionary key-value pairs
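# Buggy: the values for "image1.5" and "image0.5" are swapped.
data = {
    "image1.5": image_0_5,
    "image1.0": image_1_0,
    "image0.5": image_1_5,
}

# Fixed: every key matches the scale of its value.
data = {
    "image1.5": image_1_5,
    "image1.0": image_1_0,
    "image0.5": image_0_5,
}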
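One way to make this kind of mismatch harder to write is to derive each key and its value from the same scale, instead of typing both by hand. The sketch below is only an illustration: `fake_image`, `build_multiscale` and `check_descending` are hypothetical names that do not appear in the original article, and the stand-in image only carries a `shape` tuple so that the example runs without any tensor library.

```python
from types import SimpleNamespace


def fake_image(width):
    """Hypothetical stand-in for a tensor: it only exposes a shape tuple."""
    return SimpleNamespace(shape=(3, width, width))


def build_multiscale(base_width, scales=(1.5, 1.0, 0.5)):
    # Key and value come from the same scale, so they cannot drift apart.
    return {f"image{s}": fake_image(int(base_width * s)) for s in scales}


def check_descending(data, keys):
    # Explicit if/raise check; unlike assert, it is not stripped by `python -O`.
    widths = [data[k].shape[-1] for k in keys]
    if any(a <= b for a, b in zip(widths, widths[1:])):
        raise ValueError(f"Sizes are not strictly descending for {keys}: {widths}")


data = build_multiscale(256)
check_descending(data, ["image1.5", "image1.0", "image0.5"])
```

With real tensors, the body of `fake_image` would be replaced by whatever resizing call the project already uses; the point is only that the scale appears once.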
A check with assert, or with if plus raise, catches such a mismatch early:
assert data['image1.5'].shape[-1] > data['image1.0'].shape[-1] > data['image0.5'].shape[-1]

Omission of boilerplate code
loss = loss_fn(model(X), Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
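To see why a forgotten `optimizer.zero_grad()` matters, recall that PyTorch adds new gradients to the existing `.grad` values instead of overwriting them. The toy sketch below mimics that behaviour in plain Python; `Param`, `backward`, `sgd_step` and `train` are hypothetical helpers written for this illustration, not PyTorch APIs.

```python
class Param:
    """Toy scalar parameter whose gradient accumulates, like a tensor's .grad."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def backward(param, x, y):
    # loss = (value * x - y) ** 2, so d(loss)/d(value) = 2 * (value * x - y) * x.
    # The new gradient is added to the old one, not written over it.
    param.grad += 2 * (param.value * x - y) * x


def sgd_step(param, lr):
    param.value -= lr * param.grad


def train(steps, zero_grad):
    w = Param(0.0)
    for _ in range(steps):
        if zero_grad:
            w.grad = 0.0  # the boilerplate line that is easy to forget
        backward(w, x=1.0, y=1.0)
        sgd_step(w, lr=0.1)
    return w.value


with_reset = train(3, zero_grad=True)
without_reset = train(3, zero_grad=False)
```

After three steps toward the target value 1.0, the run that resets its gradient sits at about 0.488, while the run that never resets has already overshot to about 1.008, because every step reuses all earlier gradients. No error is raised in either case, which is what makes the omission hard to spot.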

# The same loop with automatic mixed precision (AMP).
optimizer.zero_grad()
with autocast(enabled=args.use_fp16):
    loss = loss_fn(model(X), Y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
Lines such as optimizer.zero_grad() and scaler.scale(loss).backward() are easy to leave out.

Custom snippets

Code modularization
import torch


def clip_grad(params, mode, clip_cfg: dict):
    """Clip gradients in place, either by total norm or by value."""
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `max_norm`.")
        torch.nn.utils.clip_grad_norm_(
            params, max_norm=clip_cfg.get("max_norm"), norm_type=clip_cfg.get("norm_type", 2.0)
        )
    elif mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `clip_value`.")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg.get("clip_value"))
    else:
        raise NotImplementedError(f"Unsupported clip mode: {mode!r}")

from functools import partial
from itertools import chain

from torch.cuda.amp import GradScaler, autocast

# `ops` is the author's own module that contains `clip_grad` from above.


class Scaler:
    def __init__(
        self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
    ) -> None:
        self.optimizer = optimizer
        self.set_to_none = set_to_none
        self.autocast = autocast(enabled=use_fp16)
        self.scaler = GradScaler(enabled=use_fp16)

        if clip_grad:
            self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
        else:
            self.grad_clip_ops = None

    def calculate_grad(self, loss):
        self.scaler.scale(loss).backward()
        if self.grad_clip_ops is not None:
            # Caution: an enabled GradScaler allows unscale_() only once per optimizer
            # between update() calls. If calculate_grad() runs several times before
            # update_grad() (gradient accumulation), this call raises a RuntimeError.
            self.scaler.unscale_(self.optimizer)
            self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))

    def update_grad(self):
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=self.set_to_none)

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        self.scaler.load_state_dict(state_dict)

scaler = pipeline.Scaler(
    optimizer=optimizer,
    use_fp16=cfg.train.use_amp,
    set_to_none=cfg.train.optimizer.set_to_none,
    clip_grad=cfg.train.grad_clip.enable,
    clip_mode=cfg.train.grad_clip.mode,
    clip_cfg=cfg.train.grad_clip.cfg,
)

with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
    probs, loss, loss_str = model(
        data=batch_data, iter_percentage=counter.curr_iter / counter.num_total_iters
    )
    loss = loss / cfg.train.grad_acc_step
scaler.calculate_grad(loss=loss)
if counter.every_n_iters(cfg.train.grad_acc_step):  # Accumulates scaled gradients.
    scaler.update_grad()
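
The `counter` object in the loop above is not shown in the article. A minimal sketch of what such a helper might look like follows, to make the accumulation schedule concrete. `IterCounter` and its `step` method are assumptions made for this illustration; in particular, firing on the final iteration so that leftover gradients are not dropped is a design choice of this sketch, not something the source confirms.

```python
class IterCounter:
    """Hypothetical stand-in for the `counter` object used in the training loop."""

    def __init__(self, num_total_iters):
        self.num_total_iters = num_total_iters
        self.curr_iter = 0  # 0-based index of the iteration in progress

    def every_n_iters(self, n):
        # True after every n-th iteration, and also on the very last one.
        finished = self.curr_iter + 1
        return finished % n == 0 or finished == self.num_total_iters

    def step(self):
        self.curr_iter += 1


counter = IterCounter(num_total_iters=7)
update_iters = []
for _ in range(counter.num_total_iters):
    # forward pass and scaler.calculate_grad(loss=loss) would go here
    if counter.every_n_iters(3):
        update_iters.append(counter.curr_iter)  # scaler.update_grad() would go here
    counter.step()
```

With 7 iterations and an accumulation step of 3, the optimizer would be updated on iterations 2, 5 and 6 (0-based), so the last, incomplete group of gradients is still applied.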