A Walkthrough of MMDetection's build_optimizer Module
2022-06-10 16:53:00 【武乐乐~】
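Preface
Earlier posts covered build_dataset, build_dataloader, and build_model in detail. The optimizer is the last ingredient needed before training can start, so this post explains how mmdetection builds it.
1. Overall flow

The overall flow resembles the way the model is built. First, optimizer registries are set up, and the DefaultOptimizerConstructor class is registered in one of them. Then the build_from_cfg function instantiates an optimizer object from the optimizer config dict. The sections below explain how each component works internally.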
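The registry-plus-build_from_cfg pattern described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not mmcv's actual implementation: the `Registry` class here is a simplified stand-in, and `ToySGD` is a hypothetical optimizer invented for the example.

```python
class Registry:
    """Toy name -> class table (an illustrative stand-in, not mmcv's Registry)."""

    def __init__(self, name):
        self.name = name
        self._module_dict = {}

    def register_module(self):
        # Used as a decorator factory: @REGISTRY.register_module()
        def _register(cls):
            self._module_dict[cls.__name__] = cls
            return cls
        return _register

    def get(self, key):
        return self._module_dict.get(key)


def build_from_cfg(cfg, registry):
    """Instantiate the class named by cfg['type'], passing the other keys as kwargs."""
    args = dict(cfg)  # shallow copy so the caller's dict is left untouched
    obj_type = args.pop('type')
    obj_cls = registry.get(obj_type)
    if obj_cls is None:
        raise KeyError(f'{obj_type} is not in the {registry.name} registry')
    return obj_cls(**args)


OPTIMIZERS = Registry('optimizer')


@OPTIMIZERS.register_module()
class ToySGD:
    # Hypothetical optimizer, used only for this sketch.
    def __init__(self, lr, momentum=0.0):
        self.lr = lr
        self.momentum = momentum


cfg = dict(type='ToySGD', lr=0.00125, momentum=0.9)
opt = build_from_cfg(cfg, OPTIMIZERS)
print(type(opt).__name__, opt.lr, opt.momentum)
```

The point of the design is that the config only names a class by string, so swapping the optimizer means editing a dict rather than the training code.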
2. The optimizer config dict
This post again uses the default config file faster_rcnn_r50_fpn.py as the example. The fields related to the optimizer are:
# optimizer
optimizer = dict(type='SGD', lr=0.00125, momentum=0.9, weight_decay=0.0001)
As you can see, the SGD optimizer is used by default.
3. Optimizer registries
Two registries are set up (the code itself lives in mmcv):
OPTIMIZERS = Registry('optimizer')
OPTIMIZER_BUILDERS = Registry('optimizer builder')
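3.1. The OPTIMIZERS registry
The OPTIMIZERS registry holds the optimizers that PyTorch provides (the original post lists them in a figure).
Here is a brief look at how this part is built (mmcv/runner/optimizer/builder.py): the code iterates over torch.optim with dir, then registers each optimizer class via register_module()(_optim).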
import inspect

import torch


def register_torch_optimizers():
    torch_optimizers = []
    for module_name in dir(torch.optim):
        if module_name.startswith('__'):
            continue
        _optim = getattr(torch.optim, module_name)
        if inspect.isclass(_optim) and issubclass(_optim,
                                                  torch.optim.Optimizer):
            # Register torch's built-in optimizers in OPTIMIZERS.
            OPTIMIZERS.register_module()(_optim)
            torch_optimizers.append(module_name)
    return torch_optimizers


TORCH_OPTIMIZERS = register_torch_optimizers()
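The scanning technique can be tried without installing torch. The sketch below is an illustration only: `fake_optim` is a made-up namespace standing in for the torch.optim module, and `Optimizer`, `SGD`, and `Adam` are empty placeholder classes, not the real ones.

```python
import inspect
import types


class Optimizer:
    """Placeholder playing the role of torch.optim.Optimizer."""


class SGD(Optimizer):
    pass


class Adam(Optimizer):
    pass


# Fake namespace playing the role of the torch.optim module.
# lr_scheduler is a non-class attribute that the scan must skip.
fake_optim = types.SimpleNamespace(
    Optimizer=Optimizer, SGD=SGD, Adam=Adam, lr_scheduler=object())

REGISTERED = {}


def register_optimizers(namespace, base_cls):
    """Register every class in `namespace` that subclasses `base_cls`."""
    names = []
    for name in dir(namespace):
        if name.startswith('__'):
            continue  # skip dunder attributes
        candidate = getattr(namespace, name)
        if inspect.isclass(candidate) and issubclass(candidate, base_cls):
            REGISTERED[name] = candidate
            names.append(name)
    return names


names = register_optimizers(fake_optim, Optimizer)
print(names)
```

Note that `issubclass(Optimizer, Optimizer)` is true, so the base class itself passes the filter along with its subclasses; the same condition appears in the excerpt above.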
3.2. The OPTIMIZER_BUILDERS registry
The other registry mainly holds the class below (mmcv/runner/optimizer/default_constructor.py). Only the class initialization part is shown here.
@OPTIMIZER_BUILDERS.register_module()
class DefaultOptimizerConstructor:
    """Default constructor for optimizers

    Args:
        model (:obj:`nn.Module`): The model with parameters to be optimized.
        optimizer_cfg (dict): The config dict of the optimizer.
            Positional fields are

                - `type`: class name of the optimizer.

            Optional fields are

                - any arguments of the corresponding optimizer type, e.g.,
                  lr, weight_decay, momentum, etc.
        paramwise_cfg (dict, optional): Parameter-wise options.

    Example 1:
        >>> model = torch.nn.modules.Conv1d(1, 1, 1)
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, momentum=0.9,
        >>>                      weight_decay=0.0001)
        >>> paramwise_cfg = dict(norm_decay_mult=0.)
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)

    Example 2:
        >>> # assume model have attribute model.backbone and model.cls_head
        >>> optimizer_cfg = dict(type='SGD', lr=0.01, weight_decay=0.95)
        >>> paramwise_cfg = dict(custom_keys={
                '.backbone': dict(lr_mult=0.1, decay_mult=0.9)})
        >>> optim_builder = DefaultOptimizerConstructor(
        >>>     optimizer_cfg, paramwise_cfg)
        >>> optimizer = optim_builder(model)
        >>> # Then the `lr` and `weight_decay` for model.backbone is
        >>> # (0.01 * 0.1, 0.95 * 0.9). `lr` and `weight_decay` for
        >>> # model.cls_head is (0.01, 0.95).
    """

    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        if not isinstance(optimizer_cfg, dict):
            raise TypeError('optimizer_cfg should be a dict',
                            f'but got {type(optimizer_cfg)}')
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = {} if paramwise_cfg is None else paramwise_cfg
        self.base_lr = optimizer_cfg.get('lr', None)
        self.base_wd = optimizer_cfg.get('weight_decay', None)
        self._validate_cfg()
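The arithmetic in Example 2 of the docstring can be reproduced with a small sketch. This is a simplified illustration, not mmcv's add_params method: `build_param_groups` is a hypothetical helper, the parameter names are invented, and matching a custom key by plain substring is an assumption made for the example.

```python
import math


def build_param_groups(param_names, base_lr, base_wd, custom_keys):
    """Return one dict per parameter with lr and weight_decay already scaled.

    Hypothetical helper: a key in `custom_keys` applies to a parameter when
    it occurs as a substring of the parameter name (an assumption).
    """
    groups = []
    for name in param_names:
        group = {'name': name, 'lr': base_lr, 'weight_decay': base_wd}
        for key, mults in custom_keys.items():
            if key in name:
                group['lr'] = base_lr * mults.get('lr_mult', 1.0)
                group['weight_decay'] = base_wd * mults.get('decay_mult', 1.0)
                break  # first matching key wins in this sketch
        groups.append(group)
    return groups


# Invented parameter names for a model with a backbone and a cls_head.
names = ['backbone.conv1.weight', 'cls_head.fc.weight']
groups = build_param_groups(
    names,
    base_lr=0.01,
    base_wd=0.95,
    custom_keys={'backbone': dict(lr_mult=0.1, decay_mult=0.9)},
)
for g in groups:
    print(g['name'], g['lr'], g['weight_decay'])
```

Backbone parameters end up with roughly (0.01 * 0.1, 0.95 * 0.9), while the head keeps the base values (0.01, 0.95), matching the docstring's description.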
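4. Instantiating the optimizer object
With the config dict and the registries in place, the optimizer object can be instantiated. The entry point for building the optimizer is in mmdet/apis/train.py: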
# build runner
optimizer = build_optimizer(model, cfg.optimizer)
Now look at the build_optimizer function:
import copy


def build_optimizer_constructor(cfg):
    # Instantiates the optimizer constructor object.
    return build_from_cfg(cfg, OPTIMIZER_BUILDERS)


def build_optimizer(model, cfg):
    optimizer_cfg = copy.deepcopy(cfg)
    # If optimizer_cfg has no "constructor" key, fall back to
    # 'DefaultOptimizerConstructor'.
    constructor_type = optimizer_cfg.pop('constructor',
                                         'DefaultOptimizerConstructor')
    # Same idea: defaults to None when the key is absent.
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    # This ends up calling build_from_cfg.
    optim_constructor = build_optimizer_constructor(
        dict(
            type=constructor_type,
            optimizer_cfg=optimizer_cfg,
            paramwise_cfg=paramwise_cfg))
    optimizer = optim_constructor(model)
    return optimizer
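As the code shows, build_optimizer calls build_optimizer_constructor, which in turn calls build_from_cfg to instantiate the constructor class, producing the optim_constructor object. Now comes the interesting part. So far only the OPTIMIZER_BUILDERS registry has been used, and the OPTIMIZERS registry has not. Where is it used? In this call: optimizer = optim_constructor(model).
Go back to the DefaultOptimizerConstructor class from section 3.2. It implements the __call__ method, excerpted here: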
    def __call__(self, model):
        if hasattr(model, 'module'):
            model = model.module

        optimizer_cfg = self.optimizer_cfg.copy()
        # if no paramwise option is specified, just use the global setting
        if not self.paramwise_cfg:
            optimizer_cfg['params'] = model.parameters()
            return build_from_cfg(optimizer_cfg, OPTIMIZERS)

        # set param-wise lr and weight decay recursively
        params = []
        self.add_params(params, model)
        optimizer_cfg['params'] = params

        return build_from_cfg(optimizer_cfg, OPTIMIZERS)
Here build_from_cfg(optimizer_cfg, OPTIMIZERS) is what actually creates the optimizer object.
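The two-stage flow (first build a constructor from OPTIMIZER_BUILDERS, then let it build the optimizer from OPTIMIZERS) can be sketched end to end with plain dicts as registries. Everything named Toy below is hypothetical, and a list of numbers stands in for the model's parameters; this is a rough sketch of the control flow, not the mmcv code.

```python
import copy

# Plain dicts stand in for the two registries in this sketch.
OPTIMIZERS = {}
OPTIMIZER_BUILDERS = {}


class ToySGD:
    # Hypothetical optimizer: just stores what it was given.
    def __init__(self, params, lr):
        self.params = list(params)
        self.lr = lr


class ToyConstructor:
    # Hypothetical constructor mirroring the __init__/__call__ split.
    def __init__(self, optimizer_cfg, paramwise_cfg=None):
        self.optimizer_cfg = optimizer_cfg
        self.paramwise_cfg = paramwise_cfg or {}

    def __call__(self, model_params):
        cfg = self.optimizer_cfg.copy()
        optimizer_cls = OPTIMIZERS[cfg.pop('type')]  # stage 2: OPTIMIZERS lookup
        return optimizer_cls(params=model_params, **cfg)


OPTIMIZERS['ToySGD'] = ToySGD
OPTIMIZER_BUILDERS['ToyConstructor'] = ToyConstructor


def build_optimizer(model_params, cfg):
    optimizer_cfg = copy.deepcopy(cfg)  # never mutate the user's config
    constructor_type = optimizer_cfg.pop('constructor', 'ToyConstructor')
    paramwise_cfg = optimizer_cfg.pop('paramwise_cfg', None)
    # Stage 1: OPTIMIZER_BUILDERS lookup creates the constructor.
    constructor = OPTIMIZER_BUILDERS[constructor_type](
        optimizer_cfg=optimizer_cfg, paramwise_cfg=paramwise_cfg)
    return constructor(model_params)


cfg = dict(type='ToySGD', lr=0.00125, paramwise_cfg=dict(norm_decay_mult=0.))
optimizer = build_optimizer([1.0, 2.0], cfg)
print(type(optimizer).__name__, optimizer.lr, optimizer.params)
```

Because the config is deep-copied before the `constructor` and `paramwise_cfg` keys are popped, the caller's dict survives intact and can be reused or logged afterwards.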
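Summary
This post walked through how mmdetection builds the optimizer. There are, of course, many more code details worth studying. In practice, optimizers are used with all kinds of flexible settings, and relying on the single OPTIMIZERS registry alone would be inconvenient. By going through a "factory", OPTIMIZER_BUILDERS, mmdetection gives optimizer construction that flexibility (for example, optimizing only some of the parameters, or adding parameters to optimize). This design pattern is worth learning from. If you have questions, feel free to add me on WeChat (vx: wulele2541612007) and I will invite you to a discussion group.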
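Further reading
Introduction to the Config class in mmcv
Introduction to the Registry class in mmcv
Building the dataset class in mmdetection
Building the dataloader class in mmdetection
Building the model in mmdetection
Training mmdetection on your own COCO dataset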