Be Careful with Your Dictionaries and Boilerplate Code
2022-07-30 21:24:00 【InfoQ】
Mismatched dictionary keys and values
# Wrong: the keys and values do not match
data = {
    "image1.5": image_0_5,
    "image1.0": image_1_0,
    "image0.5": image_1_5,
}

# Correct
data = {
    "image1.5": image_1_5,
    "image1.0": image_1_0,
    "image0.5": image_0_5,
}
a+b+c
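A cheap safeguard is to check an invariant the values must satisfy, using an assert statement (or an if check that calls raise). Here the image at scale 1.5 must be wider than the one at 1.0, which in turn must be wider than the one at 0.5: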
assert data['image1.5'].shape[-1] > data['image1.0'].shape[-1] > data['image0.5'].shape[-1]
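Another option is to make the mismatch impossible rather than merely detectable, by deriving each key and its value from the same variable. The sketch below is a minimal illustration under assumptions: resize and FakeImage are hypothetical stand-ins (FakeImage only carries a shape) for your real resizing function and tensors. It uses if plus raise instead of assert, because assert statements are skipped when Python runs with the -O flag.

```python
from dataclasses import dataclass


@dataclass
class FakeImage:
    """Hypothetical stand-in for an image tensor: it only carries a shape."""
    shape: tuple


def resize(base_size: int, scale: float) -> FakeImage:
    """Hypothetical resize helper: a square 3-channel "image" of side base_size * scale."""
    side = int(base_size * scale)
    return FakeImage(shape=(3, side, side))


def build_multiscale(base_size: int, scales=(1.5, 1.0, 0.5)) -> dict:
    # Each key and its value come from the same `s`, so they cannot drift apart.
    return {f"image{s}": resize(base_size, s) for s in scales}


def check_multiscale(data: dict) -> None:
    # if + raise still runs under `python -O`, unlike assert.
    widths = [data[k].shape[-1] for k in ("image1.5", "image1.0", "image0.5")]
    if not widths[0] > widths[1] > widths[2]:
        raise ValueError(f"image widths out of order: {widths}")
```

For example, build_multiscale(64) yields the keys image1.5, image1.0 and image0.5 with widths 96, 64 and 32, and passes check_multiscale; swapping two values makes the check raise ValueError.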
Omissions in boilerplate code
loss = loss_fn(model(X), Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
with autocast(enabled=args.use_fp16):
    loss = loss_fn(model(X), Y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
scaler.scale(loss).backward()
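Because these few lines are easy to get slightly wrong (a forgotten zero_grad, or loss.backward() instead of scaler.scale(loss).backward()), one remedy is to write them exactly once as a function and check the call order. The sketch below assumes hypothetical recorder classes (FakeOptimizer, FakeScaler, FakeLoss) in place of a real torch.optim optimizer and GradScaler, so it runs without PyTorch; amp_train_step and its autocast_ctx parameter are likewise illustrative names, not part of any library.

```python
import contextlib


class FakeLoss:
    """Hypothetical stand-in for a loss tensor; it only logs backward()."""
    def __init__(self, log):
        self.log = log

    def backward(self):
        self.log.append("backward")


class FakeScaler:
    """Hypothetical stand-in for GradScaler; it only logs each call."""
    def __init__(self, log):
        self.log = log

    def scale(self, loss):
        self.log.append("scale")
        return loss

    def step(self, optimizer):
        self.log.append("step")

    def update(self):
        self.log.append("update")


class FakeOptimizer:
    """Hypothetical stand-in for a torch.optim optimizer."""
    def __init__(self, log):
        self.log = log

    def zero_grad(self):
        self.log.append("zero_grad")


def amp_train_step(compute_loss, optimizer, scaler, autocast_ctx=contextlib.nullcontext):
    """One mixed-precision step with the call order fixed in a single place."""
    optimizer.zero_grad()
    with autocast_ctx():
        loss = compute_loss()  # forward pass runs under autocast
    scaler.scale(loss).backward()  # backward on the scaled loss, not loss.backward()
    scaler.step(optimizer)  # a real GradScaler unscales and skips the step on inf/NaN grads
    scaler.update()  # a real GradScaler adjusts its scale factor here
    return loss
```

With real objects the same body applies, for instance amp_train_step(lambda: loss_fn(model(X), Y), optimizer, scaler, autocast_ctx=lambda: autocast(enabled=args.use_fp16)).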
Custom snippets
Modularizing the code
import torch


def clip_grad(params, mode, clip_cfg: dict):
    """Clip gradients in place, by total norm ("norm") or element-wise by value ("value")."""
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `max_norm`.")
        torch.nn.utils.clip_grad_norm_(
            params, max_norm=clip_cfg.get("max_norm"), norm_type=clip_cfg.get("norm_type", 2.0)
        )
    elif mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `clip_value`.")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg.get("clip_value"))
    else:
        raise NotImplementedError(f"Unsupported clip mode: {mode}")
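The two modes behave differently. Norm clipping rescales all gradients by one common factor so their combined norm is at most max_norm, which preserves the update direction; value clipping clamps each element independently to [-clip_value, clip_value], which can change the direction. The pure-Python sketch below mimics that arithmetic on flat lists of floats. It is an illustration, not PyTorch's implementation, although torch.nn.utils.clip_grad_norm_ likewise scales by max_norm / (total_norm + 1e-6), capped at 1.

```python
import math


def clip_by_norm(grads, max_norm, norm_type=2.0):
    """Scale all gradients by one factor so their total norm is at most max_norm."""
    if math.isinf(norm_type):
        total = max(abs(g) for g in grads)
    else:
        total = sum(abs(g) ** norm_type for g in grads) ** (1.0 / norm_type)
    coef = min(max_norm / (total + 1e-6), 1.0)  # never scale gradients up
    return [g * coef for g in grads]


def clip_by_value(grads, clip_value):
    """Clamp each gradient element independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]
```

With grads = [3.0, 4.0] and a threshold of 1.0, norm clipping gives roughly [0.6, 0.8] (same direction), while value clipping gives [1.0, 1.0].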
from functools import partial
from itertools import chain

from torch.cuda.amp import GradScaler, autocast

# `ops` is the author's module that contains clip_grad above (its import path is not shown).


class Scaler:
    def __init__(
        self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
    ) -> None:
        self.optimizer = optimizer
        self.set_to_none = set_to_none
        self.autocast = autocast(enabled=use_fp16)
        self.scaler = GradScaler(enabled=use_fp16)
        if clip_grad:
            self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
        else:
            self.grad_clip_ops = None

    def calculate_grad(self, loss):
        # Backward on the scaled loss; gradients accumulate in .grad across calls.
        self.scaler.scale(loss).backward()

    def update_grad(self):
        if self.grad_clip_ops is not None:
            # Unscale once per optimizer step, right before clipping. Doing this in
            # calculate_grad would call unscale_ on every accumulation step, which
            # GradScaler rejects with a RuntimeError when AMP is enabled.
            self.scaler.unscale_(self.optimizer)
            self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=self.set_to_none)

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:
        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.
        If this instance is not enabled, returns an empty dict.
        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.
        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        self.scaler.load_state_dict(state_dict)
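Why Scaler calls self.scaler.unscale_(self.optimizer) before running the clipping op: under AMP the gradients stored in .grad are multiplied by the current loss scale, so a threshold such as max_norm would be compared against inflated values. The plain-Python arithmetic below uses made-up gradients and GradScaler's default initial scale of 2**16 to show the effect; the helpers l2_norm and clip_norm are illustrative only.

```python
import math


def l2_norm(values):
    return math.sqrt(sum(v * v for v in values))


def clip_norm(values, max_norm):
    coef = min(max_norm / (l2_norm(values) + 1e-6), 1.0)
    return [v * coef for v in values]


scale = 2.0 ** 16  # GradScaler's default init_scale
true_grads = [0.3, 0.4]  # made-up gradients with norm 0.5, already below max_norm=1.0
scaled = [g * scale for g in true_grads]  # what .grad holds after scaler.scale(loss).backward()

# Wrong: clip the scaled gradients, and only afterwards divide by the scale (as step() would).
wrong = [g / scale for g in clip_norm(scaled, max_norm=1.0)]
# Right: unscale first (the job of unscale_), then clip.
right = clip_norm([g / scale for g in scaled], max_norm=1.0)
```

Here right stays [0.3, 0.4] because its norm is already under the threshold, whereas wrong shrinks the gradients to a norm of about 1.5e-5.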
scaler = pipeline.Scaler(
    optimizer=optimizer,
    use_fp16=cfg.train.use_amp,
    set_to_none=cfg.train.optimizer.set_to_none,
    clip_grad=cfg.train.grad_clip.enable,
    clip_mode=cfg.train.grad_clip.mode,
    clip_cfg=cfg.train.grad_clip.cfg,
)

with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
    probs, loss, loss_str = model(
        data=batch_data, iter_percentage=counter.curr_iter / counter.num_total_iters
    )
loss = loss / cfg.train.grad_acc_step  # average the loss over the accumulation window
scaler.calculate_grad(loss=loss)  # accumulates scaled gradients in .grad
if counter.every_n_iters(cfg.train.grad_acc_step):
    scaler.update_grad()  # apply the accumulated gradients, then zero them
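The counter object and its every_n_iters method come from the author's own training code, which is not shown. A minimal hypothetical version, assuming curr_iter is 1-based and incremented before the check, might look like the sketch below. The simulation also illustrates why each loss is divided by grad_acc_step: gradients are linear in the loss, so the accumulated value equals the mean over the micro-batches instead of their sum.

```python
class IterCounter:
    """Hypothetical stand-in for the `counter` object used in the loop above."""

    def __init__(self, num_total_iters: int):
        self.curr_iter = 0
        self.num_total_iters = num_total_iters

    def step(self):
        self.curr_iter += 1

    def every_n_iters(self, n: int) -> bool:
        # True on iterations n, 2n, 3n, ... (curr_iter is 1-based after step()).
        return n > 0 and self.curr_iter % n == 0


def simulate(micro_batch_grads, grad_acc_step):
    """Accumulate per-micro-batch gradients; return the value applied at each update."""
    counter = IterCounter(len(micro_batch_grads))
    accumulated = 0.0
    applied = []
    for g in micro_batch_grads:
        counter.step()
        accumulated += g / grad_acc_step  # same role as loss / cfg.train.grad_acc_step
        if counter.every_n_iters(grad_acc_step):
            applied.append(accumulated)  # where scaler.update_grad() would step
            accumulated = 0.0  # zero_grad() resets the accumulation buffer
    return applied
```

For example, simulate([1.0, 2.0, 3.0, 6.0], 2) returns [1.5, 4.5], the means of each pair. In this sketch a trailing incomplete window is never applied, which matters when num_total_iters is not a multiple of grad_acc_step.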