1、Focal Loss theory and code implementation
2022-07-29 06:08:00 【My hair is messy】
Preface
This post is based on an article by the blogger "When can I see Qingmeng".
Original article: https://www.jianshu.com/p/30043bcc90b6
I. Basic theory
1. Use a "soft" gamma: increasing gamma in stages during training may give a further performance improvement.
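The text does not say how gamma should be increased. The sketch below is one hypothetical stepwise schedule, assuming gamma starts at 0 (plain cross-entropy) and rises to 2 in equal stages; the function name `soft_gamma`, the start/end values and the number of stages are all illustrative assumptions, not the original author's settings.

```python
def soft_gamma(epoch, total_epochs, gamma_start=0.0, gamma_end=2.0, num_stages=4):
    """Hypothetical stepwise schedule: raise gamma from gamma_start to gamma_end
    in num_stages equal-length stages over the course of training."""
    if total_epochs <= 0 or num_stages <= 0:
        raise ValueError("total_epochs and num_stages must be positive")
    if num_stages == 1:
        return float(gamma_end)
    # Index (0 .. num_stages-1) of the stage this epoch belongs to.
    stage = min(epoch * num_stages // total_epochs, num_stages - 1)
    # Interpolate linearly between the start and end values across stages.
    return gamma_start + (gamma_end - gamma_start) * stage / (num_stages - 1)


# Example: 20 epochs, four stages -> gamma = 0, 2/3, 4/3, 2 for epochs 0-4, 5-9, 10-14, 15-19.
for epoch in (0, 5, 10, 15):
    print(epoch, round(soft_gamma(epoch, total_epochs=20), 3))
```

In a training loop one could assign `criterion.gamma = soft_gamma(epoch, num_epochs)` at the start of each epoch, since both `FocalLoss` classes in section II read `self.gamma` inside `forward`.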
2. alpha is related to the frequency of each class in the training data.
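One common heuristic for deriving a per-class alpha from class frequencies is inverse-frequency weighting. The sketch below assumes that heuristic (the helper `alpha_from_counts` is hypothetical); note that the RetinaNet paper itself tunes alpha jointly with gamma (alpha = 0.25 with gamma = 2 for the foreground class), so inverse frequency is only a starting point.

```python
import torch


def alpha_from_counts(class_counts):
    """Hypothetical heuristic: per-class alpha proportional to inverse class frequency,
    rescaled so the weights sum to the number of classes."""
    counts = torch.as_tensor(class_counts, dtype=torch.float32)
    inv_freq = 1.0 / counts.clamp(min=1.0)  # clamp guards against empty classes
    return inv_freq * len(counts) / inv_freq.sum()


# Example: a 3-class dataset with 900 / 90 / 10 samples.
alpha = alpha_from_counts([900, 90, 10])
print(alpha)  # rare classes get larger weights; the weights sum to 3
```

The result has one weight per class, meant to be looked up by each example's target label.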
3. F.nll_loss(torch.log(F.softmax(inputs, dim=1)), target) computes the same value as F.cross_entropy(inputs, target).
Conceptually, F.nll_loss one-hot encodes target into a tensor with the same shape as input, multiplies it element-wise with its first argument (the log-probabilities), sums over the class dimension, negates the result and, by default, averages over the batch.
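A quick numerical check of this equivalence, with the one-hot formulation written out by hand; the random logits and targets are illustrative. In practice `F.log_softmax` is numerically safer than `torch.log(F.softmax(...))`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
inputs = torch.randn(4, 3)           # logits: batch of 4, 3 classes
target = torch.tensor([0, 2, 1, 2])  # class indices

log_probs = torch.log(F.softmax(inputs, dim=1))
loss_nll = F.nll_loss(log_probs, target)
loss_ce = F.cross_entropy(inputs, target)

# What nll_loss does, spelled out: one-hot the target to the input's shape,
# multiply element-wise, sum over classes, negate, then average over the batch.
one_hot = F.one_hot(target, num_classes=inputs.size(1)).float()
loss_manual = -(one_hot * log_probs).sum(dim=1).mean()

print(loss_nll.item(), loss_ce.item(), loss_manual.item())
```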
[Figure: experimental results with alpha = 1 and different values of gamma]
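The figure is not reproduced in this copy. The sketch below instead tabulates the analytic per-example loss as a function of p_t (the predicted probability of the true class) for alpha = 1 and a few gamma values, which shows the same qualitative effect: gamma = 0 is plain cross-entropy, and larger gamma shrinks the loss of well-classified examples. The gamma values are an illustrative choice.

```python
import math


def focal_term(p_t, gamma, alpha=1.0):
    """Focal loss for one example, given p_t = predicted probability of the true class."""
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


gammas = [0, 0.5, 1, 2, 5]
for p_t in (0.1, 0.5, 0.9):
    row = ", ".join(f"gamma={g}: {focal_term(p_t, g):.4f}" for g in gammas)
    print(f"p_t={p_t}: {row}")
```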
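4. What problems does focal loss solve?
(1) Imbalance between classes (addressed mainly by the weighting factor alpha).
(2) Imbalance between easy and hard examples (addressed by the modulating factor (1 - p_t)^gamma, which down-weights well-classified examples).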
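To see both effects together, consider a hypothetical batch with many easy negatives and a few hard positives (the counts and probabilities below are made up for illustration). Under plain cross-entropy (gamma = 0) the easy examples contribute a large share of the total loss; with gamma = 2 their share becomes negligible.

```python
import math


def focal(p_t, gamma, alpha=1.0):
    """Per-example focal loss; gamma = 0 and alpha = 1 give plain cross-entropy."""
    return -alpha * (1 - p_t) ** gamma * math.log(p_t)


# Hypothetical batch: 1000 easy negatives (p_t = 0.99) and 10 hard positives (p_t = 0.3).
n_easy, p_easy = 1000, 0.99
n_hard, p_hard = 10, 0.3

shares = {}
for gamma in (0, 2):
    easy_total = n_easy * focal(p_easy, gamma)
    hard_total = n_hard * focal(p_hard, gamma)
    shares[gamma] = easy_total / (easy_total + hard_total)
    print(f"gamma={gamma}: easy={easy_total:.4f}, hard={hard_total:.4f}, "
          f"easy share={shares[gamma]:.2%}")
```

With these numbers the easy examples make up roughly 45% of the cross-entropy total but well under 0.1% of the focal loss total.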
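5. In RetinaNet, besides using focal loss, the initialization also gets special treatment. How is it done?
In RetinaNet, the bias b of the last conv layer of the classification subnet is set to:
$$b = -\log\frac{1-\pi}{\pi}$$
where π is the prior probability of foreground (π = 0.01 in the paper). As a result, at the start of training every anchor is predicted as foreground with probability about π, which keeps the huge number of background anchors from producing a large, destabilizing loss in the first iterations.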
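A minimal sketch of this initialization in PyTorch. The layer shape follows the RetinaNet design (3x3 conv, 256 input channels, K * A output channels, Gaussian weights with std 0.01); the values K = 80 and A = 9 here are illustrative assumptions, and `cls_logits` is a hypothetical name.

```python
import math
import torch
import torch.nn as nn

num_classes = 80  # K: hypothetical number of object classes
num_anchors = 9   # A: anchors per spatial location
prior = 0.01      # pi: foreground probability every anchor should start with

# Final 3x3 conv of a RetinaNet-style classification subnet: K * A outputs per location.
cls_logits = nn.Conv2d(256, num_anchors * num_classes, kernel_size=3, padding=1)
nn.init.normal_(cls_logits.weight, mean=0.0, std=0.01)
nn.init.constant_(cls_logits.bias, -math.log((1 - prior) / prior))

# sigmoid(b) = 1 / (1 + (1 - pi) / pi) = pi, so initial predictions sit near pi.
print(torch.sigmoid(cls_logits.bias[:3]))
```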
II. Implementation
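1. Formulas
The standard cross-entropy (CE) and focal loss (FL) for binary classification are:
$$p_t = \begin{cases} p & \text{if } y = 1 \\ 1 - p & \text{otherwise} \end{cases}$$
$$\mathrm{CE}(p_t) = -\log(p_t)$$
$$\mathrm{FL}(p_t) = -\alpha_t (1 - p_t)^{\gamma} \log(p_t)$$
where p is the predicted probability of the positive class, and α_t is defined analogously to p_t (α for y = 1, 1 - α otherwise).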
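A direct transcription of these formulas for sigmoid outputs, as a sketch: the function name `binary_focal_loss` is hypothetical, and the defaults alpha = 0.25, gamma = 2 are the values reported in the RetinaNet paper. Unlike the first implementation below, which applies one scalar alpha to every element, this version uses α_t (α for positives, 1 - α for negatives).

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    """Mean of FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t).
    logits: raw scores; targets: float tensor of 0/1 labels with the same shape."""
    p = torch.sigmoid(logits)
    p_t = p * targets + (1 - p) * (1 - targets)              # p if y = 1, else 1 - p
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # alpha if y = 1, else 1 - alpha
    # -log(p_t), computed stably from the logits (per-element BCE).
    ce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    return (alpha_t * (1 - p_t) ** gamma * ce).mean()


logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
targets = torch.tensor([1.0, 0.0, 0.0, 1.0])
print(binary_focal_loss(logits, targets))
```

With gamma = 0 and alpha = 0.5 this reduces to half the ordinary binary cross-entropy, which is a convenient sanity check.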
For the derivation of the forward and backward passes, see this Zhihu article: https://zhuanlan.zhihu.com/p/32631517
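For the binary sigmoid case (alpha = 1, label y in {+1, -1}, p_t = sigmoid(y * x)), applying the chain rule gives dFL/dx = y (1 - p_t)^γ (γ p_t log(p_t) + p_t - 1); with γ = 0 this reduces to the cross-entropy gradient y (p_t - 1). This is a derivation sketch independent of the Zhihu article; the code below checks it against PyTorch autograd.

```python
import torch


def focal_loss_from_logit(x, y, gamma):
    """Binary focal loss (alpha = 1) for logits x and a label y in {+1, -1}."""
    p_t = torch.sigmoid(y * x)
    return -(1 - p_t) ** gamma * torch.log(p_t)


def focal_grad_closed_form(x, y, gamma):
    """dFL/dx = y * (1 - p_t)^gamma * (gamma * p_t * log(p_t) + p_t - 1)."""
    p_t = torch.sigmoid(y * x)
    return y * (1 - p_t) ** gamma * (gamma * p_t * torch.log(p_t) + p_t - 1)


x = torch.linspace(-4, 4, 9, dtype=torch.float64, requires_grad=True)
for y in (1.0, -1.0):
    for gamma in (0.0, 2.0):
        (grad,) = torch.autograd.grad(focal_loss_from_logit(x, y, gamma).sum(), x)
        assert torch.allclose(grad, focal_grad_closed_form(x.detach(), y, gamma))
print("autograd matches the closed-form gradient")
```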
2. Code implementation
1. Based on binary cross-entropy
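```python
# 1. Based on binary cross-entropy
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha    # scalar weight applied to every element
        self.gamma = gamma    # focusing parameter
        self.logits = logits  # True if inputs are raw logits, False if probabilities
        self.reduce = reduce  # True: return the mean; False: per-element losses

    def forward(self, inputs, targets):
        # targets: float tensor of 0/1 labels with the same shape as inputs.
        # reduction="none" replaces the deprecated reduce=False argument.
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduction="none")
        # BCE = -log(p_t), so exp(-BCE) recovers p_t.
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss
```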
2. Implementation from the Zhihu article
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLoss(nn.Module):
    r"""
    This criterion is an implementation of Focal Loss, which is proposed in
    Focal Loss for Dense Object Detection.

        Loss(x, class) = - \alpha (1-softmax(x)[class])^gamma \log(softmax(x)[class])

    The losses are averaged across observations for each minibatch.

    Args:
        alpha(1D Tensor): the per-class scalar factor for this criterion
        gamma(float, double): gamma > 0; reduces the relative loss for well-classified examples (p > .5),
                              putting more focus on hard, misclassified examples
        size_average(bool): By default, the losses are averaged over observations for each minibatch.
                            However, if the field size_average is set to False, the losses are
                            instead summed for each minibatch.
    """

    def __init__(self, class_num, alpha=None, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        if alpha is None:
            alpha = torch.ones(class_num)
        # Variable is deprecated; store alpha as a flat (class_num,) buffer so that
        # .to(device) / .cuda() on the module moves it along with the inputs.
        self.register_buffer("alpha", torch.as_tensor(alpha, dtype=torch.float32).view(-1))
        self.gamma = gamma
        self.class_num = class_num
        self.size_average = size_average

    def forward(self, inputs, targets):
        # inputs: (N, C) logits; targets: (N,) class indices
        N = inputs.size(0)
        C = inputs.size(1)
        P = F.softmax(inputs, dim=1)

        # One-hot mask marking the target class of each row.
        class_mask = torch.zeros(N, C, dtype=inputs.dtype, device=inputs.device)
        ids = targets.view(-1, 1)
        class_mask.scatter_(1, ids, 1.)

        alpha = self.alpha[ids.view(-1)]     # (N,) weight of each example's class
        probs = (P * class_mask).sum(1)      # (N,) probability of the true class
        log_p = probs.log()

        batch_loss = -alpha * (torch.pow((1 - probs), self.gamma)) * log_p

        if self.size_average:
            loss = batch_loss.mean()
        else:
            loss = batch_loss.sum()
        return loss
```