当前位置:网站首页>「解析」FocalLoss 解决数据不平衡问题
「解析」FocalLoss 解决数据不平衡问题
2022-07-07 01:26:00 【ViatorSun】

FocalLoss 的出现,主要是为了解决 anchor-based (one-stage) 目标检测网络的分类问题。后面实例分割也常使用。
注意
这里是 目标检测网络的分类问题,而不是单纯的分类问题,这两者是不一样的。
区别在于,对于分配问题,一个图片一定是属于某一确定的类的;而检测任务中的分类,是有大量的anchor无目标的(可以称为负样本)。
分类任务
正常的 K类分类任务 的标签,是用一个K长度的向量作为标签,用one-hot(或者+smooth,这里先不考虑)来进行编码,最终的标签是一个形如[1,…, 0, …, 0]这样的。那么如果想要将背景分离出,自然可以想到增加一个1维,如果目标检测任务有K类,这里只要用K+1维来表示分类,其中1维代表无目标即可。对于分类任务而言,最后一般使用 softmax 来归一,使得所有类别的输出加和为1。
但是在检测任务中,对于无目标的anchor,我们并不希望最终结果加和为1,而是所有的概率输出都是0。 那么可以这样,我们将一个多分类任务看做多个二分类任务(sigmoid),针对每一个类别,我输出一个概率,如果接近0则代表非该类别,如果接近1,则代表这个anchor是该类别。
所以网络输出不需要用softmax来归一,而是对K长度向量的每一个分量进行sigmoid激活,让其输出值代表二分类的概率。对于无目标的anchor,gt中所有的分量都是0,代表属于每一类的概率是0,即标注为背景。
至此,FocalLoss解决的问题不是多分类问题,而是 多个二分类问题。
公式解析
首先看公式:只有 标签 y = 1 y=1 y=1时,公式/交叉熵才有意义, p t p_t pt即为标签为1时对应的预测值/模型分类正确的概率
p t = ( 1 − p r e d _ s i g m o i d ) ∗ t a r g e t + p r e d _ s i g m o i d ∗ ( 1 − t a r g e t ) p_t = (1 - pred\_sigmoid) * target + pred\_sigmoid * (1 - target) pt=(1−pred_sigmoid)∗target+pred_sigmoid∗(1−target)
C E ( p t ) = − α t log ( p t ) F L ( p t ) = − α t ( 1 − p t ) γ log ( p t ) F L ( p ) = { − α ( 1 − p ) γ log ( p ) , i f y = 1 − ( 1 − α ) p γ log ( 1 − p ) , i f y = 0 CE(p_t)=-\alpha_t \log(p_t) \\ \quad \\ FL(p_t)=-\alpha_t(1-p_t)^\gamma \log(p_t) \\ \quad \\ FL(p) = \begin{cases} \quad -\alpha(1-p)^\gamma \log(p) &, if \quad y=1 &\\ -(1-\alpha)p^\gamma \log(1-p)&,if \quad y=0 \end{cases} CE(pt)=−αtlog(pt)FL(pt)=−αt(1−pt)γlog(pt)FL(p)={ −α(1−p)γlog(p)−(1−α)pγlog(1−p),ify=1,ify=0
- 参数p[公式3]:当 p->0时(概率很低/很难区分是那个类别),调制因子 (1-p)接近1,损失不被影响,当 p->1时,(1-p)接近0,从而减小易分样本对总 loss的贡献
- 参数 γ \gamma γ:当 γ = 0 \gamma=0 γ=0 时,Focal loss就是传统的交叉熵,
当 γ \gamma γ 增加时, 调节系数 ( 1 − p t ) (1-p_t) (1−pt) 也会增加。
当 γ \gamma γ 为定值时,比如 γ = 2 \gamma=2 γ=2 ️对于easy example(p>0.5) p=0.9 的loss要比标准的交叉熵小 100倍,当 p=0.968时,要小1000+倍;️对于 hard example(p<0.5) loss要小4倍
这样的话, hard example 的权重相对提升了很多,从而增加了哪些误分类的重要性。
实验表明, γ = 2 , α = 0.75 \gamma=2,\alpha=0.75 γ=2,α=0.75时效果最好 - α \alpha α 调节正负样本不平衡系数, γ \gamma γ 控制难易样本不平衡
代码复现
在官方给的代码中,并没有 target = F.one_hot(target, num_clas) 这行代码,这是因为
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
def sigmoid_focal_loss( inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1,
gamma: float = 2, reduction: str = "none") -> torch.Tensor:
inputs = inputs.float()
targets = targets.float()
p = torch.sigmoid(inputs)
target = F.one_hot(target, num_clas+1)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss
sigmoid_focal_loss_jit: "torch.jit.ScriptModule" = torch.jit.script(sigmoid_focal_loss)
此外,torchvision 中也支持 focal loss
完整代码
官方完整代码:https://github.com/facebookresearch/
参考
- https://zhuanlan.zhihu.com/p/391186824
边栏推荐
- JVM监控及诊断工具-命令行篇
- ML之shap:基于adult人口普查收入二分类预测数据集(预测年收入是否超过50k)利用shap决策图结合LightGBM模型实现异常值检测案例之详细攻略
- 【GNN】图解GNN: A gentle introduction(含视频)
- 搞懂fastjson 对泛型的反序列化原理
- 那些自损八百的甲方要求
- Storage of dental stem cells (to be continued)
- MySQL performance_ Schema common performance diagnosis query
- Senior programmers must know and master. This article explains in detail the principle of MySQL master-slave synchronization, and recommends collecting
- Go language learning notes - Gorm use - Gorm processing errors | web framework gin (10)
- Rk3399 platform development series explanation (interruption) 13.10, workqueue work queue
猜你喜欢

【GNN】图解GNN: A gentle introduction(含视频)

Convert numbers to string strings (to_string()) convert strings to int sharp tools stoi();

Mac version PHP installed Xdebug environment (M1 version)

Financial risk control practice - decision tree rule mining template

Jmeter自带函数不够用?不如自己动手开发一个
![A freshman's summary of an ordinary student [I don't know whether we are stupid or crazy, but I know to run forward all the way]](/img/fd/7223d78fff54c574260ec0da5f41d5.png)
A freshman's summary of an ordinary student [I don't know whether we are stupid or crazy, but I know to run forward all the way]

Nvisual network visualization

JVM监控及诊断工具-命令行篇

On the discrimination of "fake death" state of STC single chip microcomputer

Peripheral driver library development notes 43: GPIO simulation SPI driver
随机推荐
Talking about reading excel with POI
[daily training -- Tencent selected 50] 292 Nim games
关于STC单片机“假死”状态的判别
3531. 哈夫曼树
Go language learning notes - Gorm use - Gorm processing errors | web framework gin (10)
Jmeter自带函数不够用?不如自己动手开发一个
yarn入门(一篇就够了)
Go语学习笔记 - gorm使用 - 原生sql、命名参数、Rows、ToSQL | Web框架Gin(九)
Opensergo is about to release v1alpha1, which will enrich the service governance capabilities of the full link heterogeneous architecture
职场经历反馈给初入职场的程序员
Experience of Niuke SQL
你不知道的互联网公司招聘黑话大全
Redisl garbled code and expiration time configuration
改变ui组件原有样式
Why does the data center need a set of infrastructure visual management system
Find duplicate email addresses
Data storage 3
How to improve website weight
go-microservice-simple(2) go-Probuffer
Peripheral driver library development notes 43: GPIO simulation SPI driver