当前位置:网站首页>「解析」FocalLoss 解决数据不平衡问题
「解析」FocalLoss 解决数据不平衡问题
2022-07-07 01:26:00 【ViatorSun】

FocalLoss 的出现,主要是为了解决 anchor-based (one-stage) 目标检测网络的分类问题。后面实例分割也常使用。
注意
这里是 目标检测网络的分类问题,而不是单纯的分类问题,这两者是不一样的。
区别在于,对于分配问题,一个图片一定是属于某一确定的类的;而检测任务中的分类,是有大量的anchor无目标的(可以称为负样本)。
分类任务
正常的 K类分类任务 的标签,是用一个K长度的向量作为标签,用one-hot(或者+smooth,这里先不考虑)来进行编码,最终的标签是一个形如[1,…, 0, …, 0]这样的。那么如果想要将背景分离出,自然可以想到增加一个1维,如果目标检测任务有K类,这里只要用K+1维来表示分类,其中1维代表无目标即可。对于分类任务而言,最后一般使用 softmax 来归一,使得所有类别的输出加和为1。
但是在检测任务中,对于无目标的anchor,我们并不希望最终结果加和为1,而是所有的概率输出都是0。 那么可以这样,我们将一个多分类任务看做多个二分类任务(sigmoid),针对每一个类别,我输出一个概率,如果接近0则代表非该类别,如果接近1,则代表这个anchor是该类别。
所以网络输出不需要用softmax来归一,而是对K长度向量的每一个分量进行sigmoid激活,让其输出值代表二分类的概率。对于无目标的anchor,gt中所有的分量都是0,代表属于每一类的概率是0,即标注为背景。
至此,FocalLoss解决的问题不是多分类问题,而是 多个二分类问题。
公式解析
首先看公式:只有 标签 y = 1 y=1 y=1时,公式/交叉熵才有意义, p t p_t pt即为标签为1时对应的预测值/模型分类正确的概率
p t = ( 1 − p r e d _ s i g m o i d ) ∗ t a r g e t + p r e d _ s i g m o i d ∗ ( 1 − t a r g e t ) p_t = (1 - pred\_sigmoid) * target + pred\_sigmoid * (1 - target) pt=(1−pred_sigmoid)∗target+pred_sigmoid∗(1−target)
C E ( p t ) = − α t log ( p t ) F L ( p t ) = − α t ( 1 − p t ) γ log ( p t ) F L ( p ) = { − α ( 1 − p ) γ log ( p ) , i f y = 1 − ( 1 − α ) p γ log ( 1 − p ) , i f y = 0 CE(p_t)=-\alpha_t \log(p_t) \\ \quad \\ FL(p_t)=-\alpha_t(1-p_t)^\gamma \log(p_t) \\ \quad \\ FL(p) = \begin{cases} \quad -\alpha(1-p)^\gamma \log(p) &, if \quad y=1 &\\ -(1-\alpha)p^\gamma \log(1-p)&,if \quad y=0 \end{cases} CE(pt)=−αtlog(pt)FL(pt)=−αt(1−pt)γlog(pt)FL(p)={ −α(1−p)γlog(p)−(1−α)pγlog(1−p),ify=1,ify=0
- 参数p[公式3]:当 p->0时(概率很低/很难区分是那个类别),调制因子 (1-p)接近1,损失不被影响,当 p->1时,(1-p)接近0,从而减小易分样本对总 loss的贡献
- 参数 γ \gamma γ:当 γ = 0 \gamma=0 γ=0 时,Focal loss就是传统的交叉熵,
当 γ \gamma γ 增加时, 调节系数 ( 1 − p t ) (1-p_t) (1−pt) 也会增加。
当 γ \gamma γ 为定值时,比如 γ = 2 \gamma=2 γ=2 ️对于easy example(p>0.5) p=0.9 的loss要比标准的交叉熵小 100倍,当 p=0.968时,要小1000+倍;️对于 hard example(p<0.5) loss要小4倍
这样的话, hard example 的权重相对提升了很多,从而增加了哪些误分类的重要性。
实验表明, γ = 2 , α = 0.75 \gamma=2,\alpha=0.75 γ=2,α=0.75时效果最好 - α \alpha α 调节正负样本不平衡系数, γ \gamma γ 控制难易样本不平衡
代码复现
在官方给的代码中,并没有 target = F.one_hot(target, num_clas) 这行代码,这是因为
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import torch
from torch.nn import functional as F
def sigmoid_focal_loss( inputs: torch.Tensor, targets: torch.Tensor, alpha: float = -1,
gamma: float = 2, reduction: str = "none") -> torch.Tensor:
inputs = inputs.float()
targets = targets.float()
p = torch.sigmoid(inputs)
target = F.one_hot(target, num_clas+1)
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
p_t = p * targets + (1 - p) * (1 - targets)
loss = ce_loss * ((1 - p_t) ** gamma)
if alpha >= 0:
alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
loss = alpha_t * loss
if reduction == "mean":
loss = loss.mean()
elif reduction == "sum":
loss = loss.sum()
return loss
sigmoid_focal_loss_jit: "torch.jit.ScriptModule" = torch.jit.script(sigmoid_focal_loss)
此外,torchvision 中也支持 focal loss
完整代码
官方完整代码:https://github.com/facebookresearch/
参考
- https://zhuanlan.zhihu.com/p/391186824
边栏推荐
- Flask1.1.4 werkzeug1.0.1 source code analysis: start the process
- k8s运行oracle
- EMMC print cqhci: timeout for tag 10 prompt analysis and solution
- What EDA companies are there in China?
- 职场经历反馈给初入职场的程序员
- 从“跑分神器”到数据平台,鲁大师开启演进之路
- [SQL practice] a SQL statistics of epidemic distribution across the country
- 绕过open_basedir
- Understand the deserialization principle of fastjson for generics
- Mac version PHP installed Xdebug environment (M1 version)
猜你喜欢

软件测试的几个关键步骤,你需要知道

Convert numbers to string strings (to_string()) convert strings to int sharp tools stoi();

职场经历反馈给初入职场的程序员

JVM监控及诊断工具-命令行篇

Dc-7 target

Reading notes of Clickhouse principle analysis and Application Practice (6)

Rk3399 platform development series explanation (WiFi) 5.52. Introduction to WiFi framework composition

JVM命令之 jstat:查看JVM統計信息

Apple CMS V10 template /mxone Pro adaptive film and television website template

Markdown 并排显示图片
随机推荐
一个简单的代数问题的求解
Career experience feedback to novice programmers
Say sqlyog deceived me!
New Year Fireworks code plus copy, are you sure you don't want to have a look
Rk3399 platform development series explanation (WiFi) 5.52. Introduction to WiFi framework composition
PTA TIANTI game exercise set l2-003 moon cake test point 2, test point 3 Analysis
[InstallShield] Introduction
为不同类型设备构建应用的三大更新 | 2022 I/O 重点回顾
基于ADAU1452的DSP及DAC音频失真分析
go-microservice-simple(2) go-Probuffer
Industrial Finance 3.0: financial technology of "dredging blood vessels"
Experience of Niuke SQL
EMMC print cqhci: timeout for tag 10 prompt analysis and solution
苹果cms V10模板/MXone Pro自适应影视电影网站模板
SAP Spartacus checkout 流程的扩展(extend)实现介绍
Jstat of JVM command: View JVM statistics
Sequential storage of stacks
Introduction to yarn (one article is enough)
If you don't know these four caching modes, dare you say you understand caching?
Flask 1.1.4 werkzeug1.0.1 analyse du code source: processus de démarrage