Be careful with your dictionaries and boilerplate code
2022-07-30 21:33:00 【InfoQ】
The dictionary keys do not match their values
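# Mismatched: the values for the 1.5x and 0.5x keys are swapped.
data = {
    "image1.5": image_0_5,
    "image1.0": image_1_0,
    "image0.5": image_1_5,
}
# Intended mapping: each key points to the image of the same scale.
data = {
    "image1.5": image_1_5,
    "image1.0": image_1_0,
    "image0.5": image_0_5,
}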
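Operations that treat the entries symmetrically, such as a+b+c, may not expose this mistake. A check with assert, or with if and raise, catches it early: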
assert data['image1.5'].shape[-1] > data['image1.0'].shape[-1] > data['image0.5'].shape[-1]
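Python drops assert statements when it runs with -O, so the same check can also be written with if and raise. The sketch below is one way to do that. check_scale_order and its prefix argument are hypothetical names, and it assumes that every key ends in a scale factor and that plain shape tuples stand in for the tensors' .shape.

```python
def check_scale_order(shapes, prefix="image"):
    """Raise ValueError unless a larger scale in the key means a larger last dimension.

    `shapes` maps keys such as "image1.5" to shape tuples (hypothetical helper).
    """
    # Order the keys by the scale encoded in their names, e.g. "image0.5" -> 0.5.
    ordered = sorted(shapes, key=lambda key: float(key[len(prefix):]))
    for small, large in zip(ordered, ordered[1:]):
        if not shapes[small][-1] < shapes[large][-1]:
            raise ValueError(
                f"{small!r} has width {shapes[small][-1]}, which is not smaller than "
                f"{large!r} with width {shapes[large][-1]}; keys and values may be swapped."
            )


# Shapes as they would look for a correctly built dictionary.
good = {"image1.5": (3, 384, 384), "image1.0": (3, 256, 256), "image0.5": (3, 128, 128)}
check_scale_order(good)  # passes silently

# The 0.5x and 1.5x entries are swapped, as in the mismatched dictionary above.
bad = {"image1.5": (3, 128, 128), "image1.0": (3, 256, 256), "image0.5": (3, 384, 384)}
try:
    check_scale_order(bad)
except ValueError as err:
    print(f"caught: {err}")
```

Unlike the one-line assert, this version reports which pair of keys broke the ordering, which helps when the dictionary has more than three entries.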
Missing boilerplate code
loss = loss_fn(model(X), Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
optimizer.zero_grad()
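PyTorch adds each backward pass onto the gradients already stored in .grad, which is why a forgotten optimizer.zero_grad() silently changes the updates. The toy below imitates that accumulation in plain Python so the effect is visible without a model. ToyParam, backward, and run are hypothetical stand-ins, not PyTorch APIs.

```python
class ToyParam:
    """Hypothetical stand-in for a parameter whose .grad accumulates across backward calls."""

    def __init__(self, value):
        self.value = value
        self.grad = 0.0


def backward(param, x, y):
    # Gradient of the squared error (w * x - y) ** 2 with respect to w,
    # added onto the existing gradient instead of overwriting it.
    param.grad += 2 * x * (param.value * x - y)


def run(num_steps, zero_grad):
    """Return the gradient seen at each step, without updating the weight."""
    w = ToyParam(1.0)
    seen = []
    for _ in range(num_steps):
        if zero_grad:
            w.grad = 0.0  # plays the role of optimizer.zero_grad()
        backward(w, x=2.0, y=3.0)
        seen.append(w.grad)
    return seen


print(run(3, zero_grad=True))   # the same gradient every step
print(run(3, zero_grad=False))  # the gradient keeps growing
```

Nothing raises an error in the second case, so the omission tends to show up only as unstable or slow training.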
with autocast(enabled=args.use_fp16):
    loss = loss_fn(model(X), Y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
scaler.scale(loss).backward()
Custom snippets
Code modularization
import torch


def clip_grad(params, mode, clip_cfg: dict):
    """Clip the gradients of `params` by norm or by value, depending on `mode`."""
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `max_norm`.")
        torch.nn.utils.clip_grad_norm_(
            params, max_norm=clip_cfg.get("max_norm"), norm_type=clip_cfg.get("norm_type", 2.0)
        )
    elif mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `clip_value`.")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg.get("clip_value"))
    else:
        raise NotImplementedError
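To make the difference between the two modes concrete, the sketch below applies each one to a flat list of gradient values in plain Python. It only illustrates the idea: clip_by_norm and clip_by_value are hypothetical names, and the real torch.nn.utils functions work in place on tensors.

```python
import math


def clip_by_norm(grads, max_norm, eps=1e-6):
    """Rescale all gradients together so their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # A coefficient below 1 shrinks every gradient by the same factor;
    # gradients that are already small enough are left untouched.
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads]


def clip_by_value(grads, clip_value):
    """Clamp each gradient independently to [-clip_value, clip_value]."""
    return [max(-clip_value, min(clip_value, g)) for g in grads]


grads = [3.0, -4.0]  # L2 norm is 5
print(clip_by_norm(grads, max_norm=1.0))     # direction kept, norm about 1
print(clip_by_value(grads, clip_value=3.5))  # only the second entry changes
```

Norm clipping preserves the direction of the overall gradient, while value clipping can change it because each entry is clamped on its own.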
from functools import partial
from itertools import chain

from torch.cuda.amp import GradScaler, autocast


class Scaler:
    def __init__(
        self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
    ) -> None:
        self.optimizer = optimizer
        self.set_to_none = set_to_none
        self.autocast = autocast(enabled=use_fp16)
        self.scaler = GradScaler(enabled=use_fp16)
        if clip_grad:
            # `ops` is presumably the module that defines `clip_grad` above.
            self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
        else:
            self.grad_clip_ops = None

    def calculate_grad(self, loss):
        self.scaler.scale(loss).backward()
        if self.grad_clip_ops is not None:
            # Gradients must be unscaled before clipping so the thresholds apply to true values.
            self.scaler.unscale_(self.optimizer)
            self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))

    def update_grad(self):
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=self.set_to_none)

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
            If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
            should be called after :meth:`update`.
        """
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        self.scaler.load_state_dict(state_dict)
scaler = pipeline.Scaler(
    optimizer=optimizer,
    use_fp16=cfg.train.use_amp,
    set_to_none=cfg.train.optimizer.set_to_none,
    clip_grad=cfg.train.grad_clip.enable,
    clip_mode=cfg.train.grad_clip.mode,
    clip_cfg=cfg.train.grad_clip.cfg,
)
with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
    probs, loss, loss_str = model(
        data=batch_data, iter_percentage=counter.curr_iter / counter.num_total_iters
    )
    loss = loss / cfg.train.grad_acc_step
scaler.calculate_grad(loss=loss)
if counter.every_n_iters(cfg.train.grad_acc_step):  # Accumulates scaled gradients.
    scaler.update_grad()
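The counter object used in the loop above is not shown in the article. A minimal sketch of what such a helper might look like follows, assuming a 1-based iteration count. IterCounter, step, and this implementation of every_n_iters are guesses for illustration, not the author's code.

```python
class IterCounter:
    """Hypothetical iteration counter used to time gradient-accumulation updates."""

    def __init__(self, num_total_iters):
        self.num_total_iters = num_total_iters
        self.curr_iter = 0

    def step(self):
        self.curr_iter += 1

    def every_n_iters(self, n):
        # True on iterations n, 2n, 3n, ... so the optimizer steps once per n batches.
        return self.curr_iter > 0 and self.curr_iter % n == 0


counter = IterCounter(num_total_iters=8)
update_iters = []
for _ in range(counter.num_total_iters):
    counter.step()  # a forward/backward pass on one batch would happen here
    if counter.every_n_iters(4):
        update_iters.append(counter.curr_iter)  # where scaler.update_grad() would run
print(update_iters)
```

With an accumulation step of 4, gradients from four batches are summed before each optimizer update, which matches dividing the loss by cfg.train.grad_acc_step in the loop.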