Problems and extensions of the monocular depth estimation model featdepth in practice
2022-07-25 14:03:00 【Apple sister】
For the principles behind the FeatDepth model and a walkthrough of its source code, refer to the following two blog posts:
While using FeatDepth in practice, the author ran into many problems with the distributed multi-GPU training tool DDP, including a GPU-allocation bug, how to print a multi-GPU global loss, how to synchronize BatchNorm globally (SyncBN), and how to shuffle the data. These rough edges are probably due to the age of the source code. This post discusses solutions to each of these problems.
1. The GPU allocation bug
After studying the DDP mechanism in depth, the author wrote up the following post for reference:
Although the global gradient update in DDP is automatic, the source code customizes a DistOptimizerHook:
class DistOptimizerHook(OptimizerHook):
    def __init__(self, grad_clip=None, coalesce=True, bucket_size_mb=-1):
        self.grad_clip = grad_clip
        self.coalesce = coalesce
        self.bucket_size_mb = bucket_size_mb

    def after_train_iter(self, runner):
        runner.optimizer.zero_grad()
        runner.outputs['loss'].backward()
        allreduce_grads(runner.model, self.coalesce, self.bucket_size_mb)
        if self.grad_clip is not None:
            self.clip_grads(runner.model.parameters())
        runner.optimizer.step()
This calls the allreduce_grads function; the coalesce parameter controls whether gradients are grouped by tensor type and copied into a contiguous 1-D buffer (contiguous in memory) before the reduce, which can save memory and communication overhead.
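The coalescing idea can be sketched in plain Python. This is not mmcv's implementation: plain lists stand in for tensors, the function names are illustrative, and the actual dist.all_reduce call is omitted.

```python
# Sketch of "coalescing": flatten several tensors into one contiguous
# 1-D buffer, reduce once over that buffer, then split it back, so a
# single communication call replaces one call per tensor.

def coalesce(tensors):
    """Flatten a list of 1-D 'tensors' (lists) into one flat buffer."""
    buffer, shapes = [], []
    for t in tensors:
        shapes.append(len(t))
        buffer.extend(t)
    return buffer, shapes

def split(buffer, shapes):
    """Inverse of coalesce: cut the buffer back into the original pieces."""
    out, i = [], 0
    for n in shapes:
        out.append(buffer[i:i + n])
        i += n
    return out

grads = [[1.0, 2.0], [3.0], [4.0, 5.0, 6.0]]
buf, shapes = coalesce(grads)
# ... a single dist.all_reduce(buf) would go here ...
restored = split(buf, shapes)
assert restored == grads
```

The bucket_size_mb parameter plays a similar role: it caps how large each coalesced bucket may grow before it is reduced.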
The principles behind the remaining three issues are also analyzed in the post above, which is recommended reading first. The sections below cover the concrete solutions in featdepth.
2. The log printing system and possible extensions
The log printing in featdepth is also integrated into the mmcv framework, via registered hooks. In the config file you can see:
log_config = dict(interval=500,
                  hooks=[dict(type='TextLoggerHook'),])
The hook configured here is TextLoggerHook, with an interval of 500. TextLoggerHook is mmcv's built-in log-printing hook type; it can be registered together with the other training hooks via runner.register_training_hooks:
runner.register_training_hooks(cfg.lr_config,
                               optimizer_config,
                               cfg.checkpoint_config,
                               cfg.log_config)
Inside that function, register_logger_hooks(log_config) registers the TextLoggerHook together with its interval. Looking at the source of mmcv's TextLoggerHook class, it takes the content to print from runner.log_buffer.output and prints it via runner.logger (the code is omitted here). The former stores the log content; the latter prints it.
runner.train() and runner.val() both call runner.run_iter(), which stores each iteration's results (log_vars) in runner.log_buffer; the LoggerHook then averages them over each interval.
def run_iter(self, data_batch, train_mode, **kwargs):
    if self.batch_processor is not None:
        outputs = self.batch_processor(
            self.model, data_batch, train_mode=train_mode, **kwargs)
    elif train_mode:
        outputs = self.model.train_step(data_batch, self.optimizer,
                                        **kwargs)
    else:
        outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
    if not isinstance(outputs, dict):
        raise TypeError('"batch_processor()" or "model.train_step()"'
                        ' and "model.val_step()" must return a dict')
    if 'log_vars' in outputs:
        self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
    self.outputs = outputs
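The interval-averaging behaviour described above can be sketched with a minimal buffer class. This is not mmcv's LogBuffer, just an illustration of the idea: values are accumulated with their sample counts, and average() produces the weighted mean since the last reset.

```python
class MiniLogBuffer:
    """Toy stand-in for a runner's log buffer."""

    def __init__(self):
        self.history = {}   # name -> list of (value, sample_count)
        self.output = {}    # name -> averaged value for printing

    def update(self, log_vars, count=1):
        # Called once per iteration with that iteration's loss values.
        for name, value in log_vars.items():
            self.history.setdefault(name, []).append((value, count))

    def average(self):
        # Called by the logger hook every `interval` iterations:
        # weighted mean over the accumulated records, then reset.
        for name, records in self.history.items():
            total = sum(v * c for v, c in records)
            n = sum(c for _, c in records)
            self.output[name] = total / n
        self.history.clear()

buf = MiniLogBuffer()
buf.update({'loss': 2.0}, count=4)
buf.update({'loss': 4.0}, count=4)
buf.average()
assert buf.output['loss'] == 3.0  # (2*4 + 4*4) / 8
```

This mirrors why num_samples is returned from batch_processor below: the per-iteration loss is weighted by its batch size when averaged.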
The log_vars dict is filled in the custom batch_processor (in trainer.py); its contents are the individual loss terms of each iteration:
def batch_processor(model, data, train_mode):
    data = change_input_variable(data)
    model_out, losses = model(data)
    log_vars = OrderedDict()
    for loss_name, loss_value in losses.items():
        if isinstance(loss_value, torch.Tensor):
            log_vars[loss_name] = loss_value.mean()
        elif isinstance(loss_value, list):
            log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
        else:
            raise TypeError(
                '{} is not a tensor or list of tensors'.format(loss_name))
    loss = sum(_value for _key, _value in log_vars.items())
    log_vars['loss'] = loss
    new_log_vars = OrderedDict()
    for name in log_vars:
        new_log_vars[str(name)] = log_vars[name].item()
    outputs = dict(loss=loss,
                   log_vars=new_log_vars,
                   num_samples=len(data[('color', 0, 0)].data))
    return outputs
The place where the log is actually printed, namely runner.logger, only prints on GPU 0. So the source code only prints the loss on the master GPU. Moreover, since the source only builds train_dataset for the training stage, only the training log can be printed; the validation set computes rmse and other metrics when ground-truth depth is available, and shows no validation loss when it is not.
Two extensions follow from this: first, printing the globally averaged loss; second, adding a validation-set loss.
To print the global average loss, synchronize the loss globally (i.e. all_reduce it) by hand in batch_processor; this works in both train and val modes:
for loss_name, loss_value in losses.items():
    dist.all_reduce(loss_value.div_(torch.cuda.device_count()))
    if isinstance(loss_value, torch.Tensor):
        log_vars[loss_name] = loss_value.mean()
    ……
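The effect of that all_reduce can be illustrated in a single process. Plain floats stand in for tensors here; dist and real GPUs are not involved in this sketch, and the function name is illustrative.

```python
# Each rank first divides its local loss by world_size; the all_reduce
# then sums the contributions, so every rank ends up holding the
# global mean of the per-GPU losses.

def simulated_all_reduce_mean(per_rank_losses):
    """Single-process simulation of all_reduce(loss / world_size)."""
    world_size = len(per_rank_losses)
    contributions = [loss / world_size for loss in per_rank_losses]
    return sum(contributions)  # the value every rank would hold

# Four GPUs report different local losses; after the reduce each of
# them would hold the global average.
global_loss = simulated_all_reduce_mean([0.8, 1.2, 1.0, 1.0])
assert abs(global_loss - 1.0) < 1e-9
```

Note that dist.all_reduce is a collective call: every rank must reach it, or the process group will hang, which is why the same code path has to run in both train and val modes.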
To display the validation-set loss, add a validation set to data_loaders:
data_loaders = [build_dataloader(dataset_train,
                                 cfg.imgs_per_gpu,
                                 cfg.workers_per_gpu,
                                 dist=True),
                build_dataloader(dataset_val,
                                 cfg.imgs_per_gpu,
                                 cfg.workers_per_gpu,
                                 dist=True)
                ]
But there is still a problem: in the model definition (net.py), val mode is set to output no loss:
def forward(self, inputs):
    outputs = self.DepthDecoder(self.DepthEncoder(inputs["color_aug", 0, 0]))
    if self.training:
        outputs.update(self.predict_poses(inputs))
        features = self.Encoder(inputs[("color", 0, 0)])
        outputs.update(self.Decoder(features, 0))
        loss_dict = self.compute_losses(inputs, outputs, features)
        return outputs, loss_dict
    return outputs
This needs to be changed to:
def forward(self, inputs):
    outputs = self.DepthDecoder(self.DepthEncoder(inputs["color_aug", 0, 0]))
    outputs.update(self.predict_poses(inputs))
    features = self.Encoder(inputs[("color", 0, 0)])
    outputs.update(self.Decoder(features, 0))
    loss_dict = self.compute_losses(inputs, outputs, features)
    return outputs, loss_dict
And in get_dataset, change the input frame ids for both train and val modes to the default value (the original val mode only inputs a single frame):
dataset = dataset(cfg.in_path,
                  filenames,
                  cfg.height,
                  cfg.width,
                  # cfg.frame_ids if training else [0],  # this is the line to modify
                  cfg.frame_ids,
                  is_train=training,
                  img_ext=img_ext,
                  gt_depth_path=cfg.gt_depth_path)
Then set validate to True in the config file. In the runner's default settings, train mode prints once per interval and val mode prints once per epoch; with that, all the log output becomes visible.
3. Global BN synchronization (SyncBN)
This problem is also explained in the blog post above. Concretely, in the _dist_train() function of trainer.py, execute the following before wrapping the model in DDP:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(local_rank)
4. The data shuffle problem
Data shuffling in the source code is also implemented via a registered hook:
runner.register_hook(DistSamplerSeedHook())
The source of DistSamplerSeedHook:
@HOOKS.register_module
class DistSamplerSeedHook(Hook):
    def before_epoch(self, runner):
        runner.data_loader.sampler.set_epoch(runner.epoch)
The random seed is set before every epoch, so the shuffle differs from epoch to epoch and no manual reset is needed.
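Why set_epoch matters can be sketched in plain Python. This is not torch's DistributedSampler, just an illustration: a shuffle seeded with the epoch number gives every rank the same permutation (so the per-rank shards stay disjoint) while the order still changes each epoch.

```python
import random

def epoch_shuffle(indices, epoch, rank, world_size):
    """Shuffle deterministically by epoch, then take this rank's shard."""
    order = list(indices)
    random.Random(epoch).shuffle(order)   # identical on every rank
    return order[rank::world_size]        # disjoint shard per rank

data = list(range(8))
# Two ranks in the same epoch see disjoint shards of one permutation,
# together covering the whole dataset exactly once.
e0_r0 = epoch_shuffle(data, epoch=0, rank=0, world_size=2)
e0_r1 = epoch_shuffle(data, epoch=0, rank=1, world_size=2)
assert sorted(e0_r0 + e0_r1) == data
# The same call is reproducible, which is what set_epoch relies on.
assert epoch_shuffle(data, epoch=0, rank=0, world_size=2) == e0_r0
```

If set_epoch were never called, every epoch would reuse the seed-0 permutation and the data order would repeat across epochs.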
Other practical problems will be covered in future updates; stay tuned.