Be Careful with Your Dictionaries and Boilerplate Code
2022-07-30 21:24:00 【InfoQ】
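Mismatched dictionary keys and values
In the following dictionary, the keys "image1.5" and "image0.5" are each paired with the other's tensor: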
data = {
    "image1.5": image_0_5,
    "image1.0": image_1_0,
    "image0.5": image_1_5,
}
The intended mapping is:
data = {
    "image1.5": image_1_5,
    "image1.0": image_1_0,
    "image0.5": image_0_5,
}
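One way to rule out this mistake, sketched here as an assumption rather than the author's code, is to derive each key and its value from the same scale factor so the two cannot drift apart. To keep the sketch dependency-free, tuples of ints stand in for image tensor shapes, and `build_multiscale` and `check_decreasing_width` are hypothetical helpers. The check uses `if ... raise` rather than `assert`, because running Python with `-O` strips assert statements.

```python
from typing import Dict, Sequence, Tuple

Shape = Tuple[int, ...]  # stands in for an image tensor's shape


def build_multiscale(base_shape: Shape, scales: Sequence[float]) -> Dict[str, Shape]:
    """Derive every key and its value from the same scale, so they cannot drift apart."""
    *lead, h, w = base_shape
    return {f"image{s}": (*lead, round(h * s), round(w * s)) for s in scales}


def check_decreasing_width(data: Dict[str, Shape], keys: Sequence[str]) -> None:
    """Raise if the last dimension does not strictly shrink along `keys`."""
    widths = [data[k][-1] for k in keys]
    if any(a <= b for a, b in zip(widths, widths[1:])):
        raise ValueError(f"widths must strictly decrease along {list(keys)}, got {widths}")


KEYS = ["image1.5", "image1.0", "image0.5"]
data = build_multiscale((3, 64, 64), [1.5, 1.0, 0.5])
check_decreasing_width(data, KEYS)  # passes: widths 96, 64, 32

# The hand-written mistake: the 1.5 and 0.5 entries swapped.
swapped = {"image1.5": data["image0.5"], "image1.0": data["image1.0"], "image0.5": data["image1.5"]}
try:
    check_decreasing_width(swapped, KEYS)
except ValueError as err:
    print(err)
```

Building the dictionary in a comprehension removes the copy-paste step where the mismatch happens; the explicit check still protects dictionaries that are assembled by hand.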
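Such a swap raises no error by itself. A cheap guard is to check an invariant the values must satisfy, either with assert or with an explicit if ... raise. Here, assuming the suffix is the resize factor, a larger scale must give a larger width:
assert data['image1.5'].shape[-1] > data['image1.0'].shape[-1] > data['image0.5'].shape[-1]
Omissions in boilerplate code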
A plain PyTorch training step:
loss = loss_fn(model(X), Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
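The same step with automatic mixed precision (AMP), where autocast and scaler (a GradScaler) come from torch.cuda.amp:
optimizer.zero_grad()
with autocast(enabled=args.use_fp16):
    loss = loss_fn(model(X), Y)
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()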
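Converting the first version into the second touches almost every line, and it is easy to miss one: a forgotten optimizer.zero_grad() makes gradients accumulate across iterations without any error, and scaler.scale(loss).backward() has to replace loss.backward() so that the scaler works on scaled gradients. One remedy is a custom editor snippet that always inserts the complete block.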
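The following minimal sketch, assuming a recent PyTorch install, shows why a forgotten `optimizer.zero_grad()` is easy to miss: nothing fails, the gradient just keeps growing. The toy loss `3 * w` and the helper `grad_after_steps` are hypothetical; the loss has gradient 3 for every value of `w`, so any other value left in `w.grad` can only come from accumulation.

```python
import torch


def grad_after_steps(num_steps: int, zero_grad: bool) -> float:
    """Run `num_steps` training steps on loss = 3 * w and return the final w.grad."""
    w = torch.nn.Parameter(torch.tensor(1.0))
    opt = torch.optim.SGD([w], lr=0.1)
    for _ in range(num_steps):
        if zero_grad:
            opt.zero_grad()
        loss = 3.0 * w
        loss.backward()  # adds d(loss)/dw = 3 to whatever w.grad already holds
        opt.step()
    return w.grad.item()


print(grad_after_steps(3, zero_grad=True))   # 3.0
print(grad_after_steps(3, zero_grad=False))  # 9.0
```

Without `zero_grad()`, the third update uses three times the intended gradient, yet the script runs to completion.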
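Code modularization
A sturdier fix than remembering every line is to write the boilerplate once, in reusable functions and classes. First, a helper that clips gradients either by norm or by value: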
import torch


def clip_grad(params, mode, clip_cfg: dict):
    """Clip gradients in place, either by their global norm or element-wise by value."""
    if mode == "norm":
        if "max_norm" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `max_norm`.")
        torch.nn.utils.clip_grad_norm_(
            params, max_norm=clip_cfg["max_norm"], norm_type=clip_cfg.get("norm_type", 2.0)
        )
    elif mode == "value":
        if "clip_value" not in clip_cfg:
            raise ValueError("`clip_cfg` must contain `clip_value`.")
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_cfg["clip_value"])
    else:
        raise NotImplementedError(f"Unsupported clip mode: {mode!r}")
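To make the two modes concrete, here is a pure-Python sketch of what each branch computes, using plain lists of floats in place of parameter gradients. It follows the idea behind `torch.nn.utils.clip_grad_norm_` (a single norm over all gradients, a `1e-6` epsilon, and a scaling coefficient capped at 1) and `torch.nn.utils.clip_grad_value_`, but it is an illustration, not PyTorch's implementation; the function names are hypothetical.

```python
import math
from typing import List

Grads = List[List[float]]  # one list of gradient values per parameter


def clip_by_global_norm(grads: Grads, max_norm: float, norm_type: float = 2.0) -> float:
    """Rescale all gradients in place so their joint norm is at most max_norm.

    Returns the total norm measured before clipping.
    """
    flat = [g for grad in grads for g in grad]
    if not flat:
        return 0.0
    if math.isinf(norm_type):
        total = max(abs(g) for g in flat)
    else:
        total = sum(abs(g) ** norm_type for g in flat) ** (1.0 / norm_type)
    coef = min(max_norm / (total + 1e-6), 1.0)  # never scale gradients up
    for grad in grads:
        for i, g in enumerate(grad):
            grad[i] = g * coef
    return total


def clip_by_value(grads: Grads, clip_value: float) -> None:
    """Clamp every gradient element into [-clip_value, clip_value] in place."""
    for grad in grads:
        for i, g in enumerate(grad):
            grad[i] = max(-clip_value, min(clip_value, g))


norm_grads = [[3.0], [4.0]]  # joint 2-norm is 5.0
print(clip_by_global_norm(norm_grads, max_norm=1.0))  # 5.0
print(norm_grads)  # roughly [[0.6], [0.8]]

value_grads = [[3.0, -0.5], [-4.0]]
clip_by_value(value_grads, clip_value=1.0)
print(value_grads)  # [[1.0, -0.5], [-1.0]]
```

Norm clipping shrinks all gradients by one common factor and so keeps the update direction; value clipping treats each element independently and can change that direction.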
Next, a Scaler class that hides the GradScaler, optional gradient clipping, and zeroing of gradients behind two calls, calculate_grad and update_grad:
from functools import partial
from itertools import chain

from torch.cuda.amp import GradScaler, autocast

import ops  # the module that holds clip_grad above; adjust to your project layout


class Scaler:
    def __init__(
        self, optimizer, use_fp16=False, *, set_to_none=False, clip_grad=False, clip_mode=None, clip_cfg=None
    ) -> None:
        self.optimizer = optimizer
        self.set_to_none = set_to_none
        self.autocast = autocast(enabled=use_fp16)
        self.scaler = GradScaler(enabled=use_fp16)
        if clip_grad:
            # The bool parameter `clip_grad` shadows the helper, hence `ops.clip_grad`.
            self.grad_clip_ops = partial(ops.clip_grad, mode=clip_mode, clip_cfg=clip_cfg)
        else:
            self.grad_clip_ops = None

    def calculate_grad(self, loss):
        # Accumulates scaled gradients; may be called several times per optimizer step.
        self.scaler.scale(loss).backward()

    def update_grad(self):
        if self.grad_clip_ops is not None:
            # Unscale once, after all gradients for this step have accumulated, so the clipping
            # threshold applies to the true gradients. Calling unscale_ twice between two
            # update() calls raises a RuntimeError, so clipping inside calculate_grad would
            # break gradient accumulation.
            self.scaler.unscale_(self.optimizer)
            self.grad_clip_ops(chain(*[group["params"] for group in self.optimizer.param_groups]))
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=self.set_to_none)

    def state_dict(self):
        r"""
        Returns the state of the scaler as a :class:`dict`. It contains five entries:

        * ``"scale"`` - a Python float containing the current scale
        * ``"growth_factor"`` - a Python float containing the current growth factor
        * ``"backoff_factor"`` - a Python float containing the current backoff factor
        * ``"growth_interval"`` - a Python int containing the current growth interval
        * ``"_growth_tracker"`` - a Python int containing the number of recent consecutive unskipped steps.

        If this instance is not enabled, returns an empty dict.

        .. note::
           If you wish to checkpoint the scaler's state after a particular iteration, :meth:`state_dict`
           should be called after :meth:`update_grad`.
        """
        return self.scaler.state_dict()

    def load_state_dict(self, state_dict):
        r"""
        Loads the scaler state. If this instance is disabled, :meth:`load_state_dict` is a no-op.

        Args:
            state_dict(dict): scaler state. Should be an object returned from a call to :meth:`state_dict`.
        """
        self.scaler.load_state_dict(state_dict)
In the training script, the scaler is built from the config:
scaler = pipeline.Scaler(
    optimizer=optimizer,
    use_fp16=cfg.train.use_amp,
    set_to_none=cfg.train.optimizer.set_to_none,
    clip_grad=cfg.train.grad_clip.enable,
    clip_mode=cfg.train.grad_clip.mode,
    clip_cfg=cfg.train.grad_clip.cfg,
)
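The constructor above reads several fields from `cfg.train`. A hypothetical configuration with those fields, plus a small start-up check that mirrors the validation inside `clip_grad`, might look like this. The values, the `SimpleNamespace` layout, and the helper `validate_grad_clip` are assumptions; only the attribute names come from the code. Validating at start-up surfaces a missing key before any training time is spent.

```python
from types import SimpleNamespace as NS

# Hypothetical values; only the attribute names are taken from the code above.
cfg = NS(
    train=NS(
        use_amp=True,
        grad_acc_step=4,
        optimizer=NS(set_to_none=True),
        grad_clip=NS(enable=True, mode="norm", cfg={"max_norm": 1.0, "norm_type": 2.0}),
    )
)

# Key that clip_grad requires for each supported mode.
REQUIRED_CLIP_KEY = {"norm": "max_norm", "value": "clip_value"}


def validate_grad_clip(grad_clip) -> None:
    """Raise early if the clipping settings would fail inside clip_grad."""
    if not grad_clip.enable:
        return
    if grad_clip.mode not in REQUIRED_CLIP_KEY:
        raise NotImplementedError(f"Unsupported clip mode: {grad_clip.mode!r}")
    key = REQUIRED_CLIP_KEY[grad_clip.mode]
    if key not in grad_clip.cfg:
        raise ValueError(f"`clip_cfg` must contain `{key}`.")


validate_grad_clip(cfg.train.grad_clip)  # passes for the config above
```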
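Each training iteration then runs the following; dividing the loss by grad_acc_step turns the accumulated gradient into an average over the accumulation window:
with torch.cuda.amp.autocast(enabled=cfg.train.use_amp):
    probs, loss, loss_str = model(
        data=batch_data, iter_percentage=counter.curr_iter / counter.num_total_iters
    )
    loss = loss / cfg.train.grad_acc_step
scaler.calculate_grad(loss=loss)  # Accumulates scaled gradients.
if counter.every_n_iters(cfg.train.grad_acc_step):  # Step once every grad_acc_step iterations.
    scaler.update_grad()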
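The accumulation logic can be checked without a GPU. In this pure-Python sketch, `IterCounter` is a hypothetical stand-in for `counter` (its `every_n_iters` is assumed to fire on the n-th, 2n-th, ... iteration, counting from zero), and one scalar per iteration stands in for the parameter gradients. It shows that dividing each loss by `grad_acc_step` makes the applied update the mean of the accumulated gradients rather than their sum.

```python
from typing import List


class IterCounter:
    """Hypothetical stand-in for the `counter` object in the training loop."""

    def __init__(self, num_total_iters: int) -> None:
        self.curr_iter = 0
        self.num_total_iters = num_total_iters

    def every_n_iters(self, n: int) -> bool:
        # True on the n-th, 2n-th, ... iteration (curr_iter counts from 0).
        return (self.curr_iter + 1) % n == 0

    def step(self) -> None:
        self.curr_iter += 1


def accumulate(iter_grads: List[float], grad_acc_step: int) -> List[float]:
    """Return the gradient applied at each optimizer step."""
    counter = IterCounter(num_total_iters=len(iter_grads))
    acc, applied = 0.0, []
    for g in iter_grads:
        acc += g / grad_acc_step  # backward of loss / grad_acc_step adds to .grad
        if counter.every_n_iters(grad_acc_step):
            applied.append(acc)  # stands in for scaler.update_grad()
            acc = 0.0  # zero_grad
        counter.step()
    return applied


print(accumulate([1.0, 3.0, 5.0, 7.0], grad_acc_step=2))  # [2.0, 6.0]
```

In the sketch, gradients from a trailing partial window (for example, a fifth value with `grad_acc_step=2`) are never applied; whether that happens in a real script depends on how `counter` continues across epochs.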