Detailed explanation of the loss module of mmdet
2022-06-30 05:13:00 【Wu lele~】
Preface
This post introduces the loss functions in mmdet. Notes and usage tips for mmdet's loss functions will be added gradually in later updates.
1. Introduction to the loss function module in mmdet
1.1. The Loss registry
Let's look at the code in mmdet/models/builder.py:
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry
MODELS = Registry('models', parent=MMCV_MODELS)  # note the extra parent argument; ignore it for now
BACKBONES = MODELS
NECKS = MODELS
ROI_EXTRACTORS = MODELS
SHARED_HEADS = MODELS
HEADS = MODELS
LOSSES = MODELS  # the LOSSES registry
DETECTORS = MODELS
Here the MODELS registry is also assigned to the other module aliases (BACKBONES, NECKS, LOSSES, ...), which is why all of the later registration and build operations go through the same MODELS registry.
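A quick sanity check (a minimal sketch, not from the original post; it assumes mmdet 2.x, where builder.py exposes both names):
from mmdet.models.builder import LOSSES, MODELS
print(LOSSES is MODELS)                # True: they are the same Registry object
print('L1Loss' in LOSSES.module_dict)  # True once mmdet.models has been imported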
1.2. Registering L1Loss()
@LOSSES.register_module()
class L1Loss(nn.Module):
    """L1 loss.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are "none", "mean" and "sum".
        loss_weight (float, optional): The weight of loss.
    """

    def __init__(self, reduction='mean', loss_weight=1.0):
        super(L1Loss, self).__init__()
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                pred,
                target,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (torch.Tensor): The predicted boxes, e.g. of shape [N].
            target (torch.Tensor): The ground-truth values, same shape as pred.
            weight (torch.Tensor, optional): Per-sample weight, same shape as
                pred. Defaults to None.
            avg_factor (int, optional): Factor used to average the total loss;
                its role overlaps with loss_weight. Defaults to None.
            reduction_override (str, optional): Overrides self.reduction for
                this call. Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        loss_bbox = self.loss_weight * l1_loss(
            pred, target, weight, reduction=reduction, avg_factor=avg_factor)
        return loss_bbox
The initialization parameters are simple, just two of them: reduction defaults to 'mean', i.e. return the averaged loss, and loss_weight scales the total L1 loss. The forward method, however, takes more parameters. pred and target need little explanation; their shapes must match (when regressing bboxes, for example, both might be [1000, 4]). weight has the same shape as pred and controls each element's contribution to the total loss. avg_factor and reduction_override are used less often; their roles overlap with loss_weight and reduction respectively, so you can usually ignore them.
With the roles of these parameters clear, let's compute a practical example:
import torch
from mmdet.models import build_loss
loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)
# compute the loss with the built module
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]]) # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
loss = obj(pred, target)
print(loss, 9/8)
The result matches the manual calculation. Briefly, the computation works as follows: torch.abs computes the element-wise absolute difference, then .mean() averages it, i.e. divides by the number of elements, here 2*4 = 8, giving 9/8.
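For comparison, the same value computed by hand, continuing the snippet above (a small sketch):
manual = torch.abs(pred - target).mean()  # element-wise |pred - target|, averaged over all 8 elements
print(manual)  # tensor(1.1250), i.e. 9/8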
Now a version with weight:
import torch
from mmdet.models import build_loss
loss_bbox = dict(type='L1Loss', loss_weight=1.0)
obj = build_loss(loss_bbox)
# compute the loss with the built module
pred = torch.Tensor([[0, 2, 3, 0], [0,2,3,0]]) # [2,4]
target = torch.Tensor([[1, 1, 1, 0], [1,1,1,1]]) # [2,4]
# weighted version: the last element gets weight = 0
weight = torch.Tensor([[1,1,1,1],[1,1,1,0]]) # [2,4]
loss = obj(pred, target, weight)
print(loss, 8/8)
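Although you can usually ignore avg_factor and reduction_override, here is a hedged sketch of what they do, reusing the obj, pred and target defined above (expected values worked out by hand):
loss_sum = obj(pred, target, reduction_override='sum')  # overrides self.reduction for this call -> tensor(9.)
loss_avg = obj(pred, target, avg_factor=4)              # divides the summed loss by 4 instead of 8 -> tensor(2.2500)
print(loss_sum, loss_avg)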
1.3. Internal implementation logic
Essentially, a decorator is used to encapsulate the loss. To put it simply, the calling process is:
1) Calling the forward method internally calls the l1_loss function;
@weighted_loss
def l1_loss(pred, target):
    """L1 loss.

    Args:
        pred (torch.Tensor): The prediction.
        target (torch.Tensor): The learning target of the prediction.

    Returns:
        torch.Tensor: Calculated loss
    """
    if target.numel() == 0:
        return pred.sum() * 0
    assert pred.size() == target.size()
    loss = torch.abs(pred - target)  # element-wise absolute difference
    return loss
2) At this point the @weighted_loss decorator is encountered, so execution jumps into the decorator first; note that l1_loss itself has not been computed yet. See mmdet/models/losses/utils.py:
import functools

def weighted_loss(loss_func):

    @functools.wraps(loss_func)
    def wrapper(pred,
                target,
                weight=None,
                reduction='mean',
                avg_factor=None,
                **kwargs):
        # get the element-wise loss
        loss = loss_func(pred, target, **kwargs)
        loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
        return loss

    return wrapper
First, loss_func (i.e. l1_loss) is wrapped once; the wrapper adds the extra parameters weight, reduction, avg_factor and **kwargs. Only then is l1_loss executed, producing the element-wise loss values.
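In other words, after decoration l1_loss itself already accepts these extra arguments. A hedged sketch (it assumes the decorated function is exported as mmdet.models.losses.l1_loss, as in mmdet 2.x):
import torch
from mmdet.models.losses import l1_loss  # the decorated function

pred = torch.Tensor([[0, 2, 3, 0], [0, 2, 3, 0]])
target = torch.Tensor([[1, 1, 1, 0], [1, 1, 1, 1]])
# reduction comes from the wrapper's signature, not from l1_loss's own body
print(l1_loss(pred, target, reduction='sum'))  # tensor(9.)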
3) In the last step, weight_reduce_loss is executed to produce the final form of the loss according to weight, reduction and avg_factor:
def reduce_loss(loss, reduction):
    """Reduce loss as specified.

    Args:
        loss (Tensor): Elementwise loss tensor.
        reduction (str): Options are "none", "mean" and "sum".

    Return:
        Tensor: Reduced loss tensor.
    """
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, elementwise_mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()
@mmcv.jit(derivate=True, coderize=True)
def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
    # if weight is specified, apply element-wise weight
    if weight is not None:
        loss = loss * weight

    # if avg_factor is not specified, just reduce the loss
    if avg_factor is None:
        loss = reduce_loss(loss, reduction)
    else:
        # if reduction is mean, then average the loss by avg_factor
        if reduction == 'mean':
            loss = loss.sum() / avg_factor
        # if reduction is 'none', then do nothing, otherwise raise an error
        elif reduction != 'none':
            raise ValueError('avg_factor can not be used with reduction="sum"')
    return loss
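A small sketch of the three branches above, calling weight_reduce_loss directly (module path assumed to be mmdet/models/losses/utils.py, as in mmdet 2.x):
import torch
from mmdet.models.losses.utils import weight_reduce_loss

elementwise = torch.tensor([1., 1., 2., 0.])              # pretend element-wise loss
print(weight_reduce_loss(elementwise))                    # default mean -> tensor(1.)
print(weight_reduce_loss(elementwise, reduction='sum'))   # tensor(4.)
print(weight_reduce_loss(elementwise, avg_factor=2))      # sum / avg_factor -> tensor(2.)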
1.4. Summary
Basically, all losses in mmdet follow the computation flow above. When using L1 Loss you don't need to worry about so many hyperparameters: just build the loss and pass in pred and target; the other parameters can usually be left at their defaults.
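For reference, in a full detector config the same dict is simply placed under a head. A hedged sketch (keys illustrative only, modeled on common configs such as the Faster R-CNN bbox head; it is not a complete head config):
bbox_head = dict(
    type='Shared2FCBBoxHead',
    loss_cls=dict(type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
    loss_bbox=dict(type='L1Loss', loss_weight=1.0))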
Summary
To be continued …