Detailed explanation of the loss module of mmdet
2022-06-30 05:13:00 【Wu lele~】
Preface
This post introduces the loss functions in mmdet. Notes and usage tips for mmdet's loss functions will be added gradually in the future.
1. Introduction to the loss function module in mmdet
1.1. The Loss registry
Let's look at the code in mmdet/models/builder.py:
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('models', parent=MMCV_MODELS)  # note the extra parent argument; not covered here

BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS  # the LOSSES registry
DETECTORS = MODELS
Note that the same MODELS registry is assigned to all the other module names as well, so BACKBONES, NECKS, LOSSES, etc. all share one underlying registry; registering a loss simply means registering a class under the LOSSES alias.
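To make this concrete, here is a minimal sketch of registering and building a loss through this registry (MyDummyLoss is a hypothetical toy class, not part of mmdet; LOSSES and build_loss are assumed importable from mmdet.models, as in the examples below):

import torch.nn as nn
from mmdet.models import LOSSES, build_loss

# hypothetical toy loss, registered the same way L1Loss is below
@LOSSES.register_module()
class MyDummyLoss(nn.Module):
    def __init__(self, loss_weight=1.0):
        super().__init__()
        self.loss_weight = loss_weight

    def forward(self, pred, target):
        return self.loss_weight * (pred - target).abs().mean()

# build_loss looks up the 'type' key in the registry and passes the
# remaining keys to the class's __init__
loss_obj = build_loss(dict(type='MyDummyLoss', loss_weight=2.0))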
1.2. Registering L1Loss()
@LOSSES.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): Predicted values, e.g. shape [N].
            target (torch.Tensor): Ground-truth values, e.g. shape [N].
            weight (torch.Tensor, optional): Per-element weight,
                shape [N]. Defaults to None.
            avg_factor (int, optional): Factor used to average the total
                loss; its role overlaps with loss_weight. Defaults to None.
            reduction_override (str, optional): Overrides self.reduction.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
The initialization parameters are simple, just two: reduction defaults to 'mean', i.e. return the averaged loss, and loss_weight controls the overall weight of the L1 loss. forward, however, takes quite a few parameters. pred and target need little explanation; their shapes must match, e.g. both [1000, 4] when dealing with bboxes. weight must have the same shape as pred and controls each element's contribution to the total loss. avg_factor and reduction_override are used less often; their roles overlap with loss_weight and reduction respectively, so you can usually ignore them.
With these parameters understood, let's compute a concrete example:
import torch
from mmdet.models import build_loss
loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)
# compute with the built loss module
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]]) # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
loss = obj(pred, target)
print(loss, 9/8)
The printed result matches the manual calculation. Briefly, the computation is: torch.abs gives the element-wise absolute difference, then .mean() divides the sum by the number of elements, here 2*4=8, which yields 9/8.
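The same value can be reproduced by hand:

import torch

# manual check of the 'mean' reduction above
pred = torch.Tensor([[0, 2, 3, 0], [0, 2, 3, 0]])
target = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
print(torch.abs(pred - target).mean())  # tensor(1.1250), i.e. 9/8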
Here is a version with weight:
import torch
from mmdet.models import build_loss
loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)
# compute with the built loss module
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]]) # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
# weighted version: the last element's weight is 0
weight = torch.Tensor([[1,1,1,1],[1,1,1,0]]) # [2,4]
loss = obj(pred, target, weight)
print(loss, 8/8)
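For completeness, a sketch of the two less-used parameters (the values here are purely illustrative): avg_factor replaces the element count as the denominator of the average, and reduction_override changes the reduction for a single call:

import torch
from mmdet.models import build_loss

loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)
pred = torch.Tensor([[0, 2, 3, 0], [0, 2, 3, 0]])    # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1]])  # [2,4]

# avg_factor: the summed loss (9) is divided by 4 instead of by the 8 elements
print(obj(pred, target, avg_factor=4), 9 / 4)
# reduction_override: return the summed loss for this call only
print(obj(pred, target, reduction_override='sum'), 9.0)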
1.3. Internal implementation logic
Essentially, a decorator is used to encapsulate the loss computation. Briefly, the calling process:
1) Calling the forward method internally invokes the l1_loss function:
@weighted_loss
def l1_loss(pred, target):
    """L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == target.size()
    loss = torch.abs(pred - target)  # element-wise absolute difference
    return loss
2) The @weighted_loss decorator is hit here, so execution first jumps into the decorator; note that the L1 loss itself is not computed at this point. See mmdet/models/losses/utils.py:
import functools

def weighted_loss(loss_func):

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # compute the element-wise loss
        loss = loss_func(pred, target, **kwargs)
        # apply weight / reduction / avg_factor to get the final loss
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
First, loss_func (i.e. l1_loss) is wrapped once, which grafts the extra parameters (weight, reduction, avg_factor and **kwargs) onto it; l1_loss is then executed, yielding the loss value between each pair of elements.
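A minimal sketch of what the wrapping amounts to (assuming weight_reduce_loss can be imported from mmdet.models.losses.utils):

import torch
from mmdet.models.losses.utils import weight_reduce_loss

pred = torch.Tensor([[0, 2, 3, 0], [0, 2, 3, 0]])
target = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1]])

# what the undecorated l1_loss returns: the element-wise loss
elementwise = torch.abs(pred - target)
# what the wrapper then does with it
print(weight_reduce_loss(elementwise, weight=None,
                         reduction='mean', avg_factor=None))  # tensor(1.1250)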
3) The last step: weight_reduce_loss is executed to apply weight, reduction, and avg_factor and produce the final form of the loss:
import mmcv
import torch.nn.functional as F

def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()


@mmcv.jit(derivate=True, coderize=True)
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is 'mean', average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', do nothing; otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
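As a sketch of how these branches combine when both weight and avg_factor are given (same tensors as before; avg_factor=7 is illustrative, and the import path is assumed as above):

import torch
from mmdet.models.losses.utils import weight_reduce_loss

elementwise = torch.Tensor([[1, 1, 2, 0], [1, 1, 2, 1]])  # |pred - target| from above
weight = torch.Tensor([[1, 1, 1, 1], [1, 1, 1, 0]])

# weight is applied element-wise first; the weighted sum (8) is then
# divided by avg_factor instead of by the number of elements
print(weight_reduce_loss(elementwise, weight=weight,
                         reduction='mean', avg_factor=7), 8 / 7)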
1.4. Summary
Essentially, all losses in mmdet follow the computation flow above. When using L1Loss, don't worry about the many hyperparameters: just build the loss and pass in pred and target; the other parameters can generally be left at their defaults.
Summary
To be continued …