当前位置:网站首页>Pytorch:交叉熵损失(CrossEntropyLoss)以及标签平滑(LabelSmoothing)的实现
Pytorch:交叉熵损失(CrossEntropyLoss)以及标签平滑(LabelSmoothing)的实现
2022-07-28 17:23:00 【我是大黄同学呀】
0. 前言
一般情况下我们都是直接调用Pytorch自带的交叉熵损失函数计算loss,但涉及到魔改以及优化时,我们需要自己动手实现loss function,在这个过程中如果能对交叉熵损失的代码实现有一定的了解会帮助我们写出更优美的代码。
其次是标签平滑这个trick通常简单有效,只需要改改损失函数既可带来性能上的提升,通常与交叉熵配合食用。
因此,本文基于这两个出发点,介绍基于Pytorch框架下的交叉熵损失实现以及标签平滑的实现。
1. 浅谈CrossEntropyLoss
相信大家对于如何计算交叉熵已经非常熟悉,常规步骤是①计算softmax得到各类别置信度;②计算交叉熵损失。但其实从Pytorch的官方文档可以看出,还有更一步到位的方法,如下:

这避免了softmax的计算。
代码实现
很简单,根据公式写代码就好了
class CELoss(nn.Module):
''' Cross Entropy Loss'''
def __init__(self):
super().__init__()
def forward(self, pred, target):
''' Args: pred: prediction of model output [N, M] target: ground truth of sampler [N] '''
eps = 1e-12
# standard cross entropy loss
loss = -1.*pred.gather(1, target.unsqueeze(-1)) + torch.log(torch.exp(pred+eps).sum(dim=1))
return loss.mean()
2. 浅谈Label Smoothing
Label Smoothing也称之为标签平滑,其实是一种防止过拟合的正则化方法。传统的分类loss采用softmax loss,先对全连接层的输出计算softmax,视为各类别的置信度概率,再利用交叉熵计算损失。

在这个过程中尽可能使得各样本在正确类别上的输出概率为1,这要使得对应的z值为+∞,这拉大了其与其他类别间的距离。
现在假设一个多分类任务标签是[1,0,0],如果它本身的label的出现了问题,这对模型的伤害是非常大的,因为在训练的过程中强行学习一个非本类的样本,并且让其概率非常高,这会影响对后验概率的估计。并且有时候类与类之间的并不是毫无关联,如果鼓励输出的概率间相差过大,这会导致一定程度上的过拟合。
因此Label Smoothing的想法是让目标不再是one-hot标签,而是变为如下形式:

其中ε为一个较小的常数,这使得softmax损失中的概率优目标不再为1和0,同时z值的最优解也不再是正无穷大,而是一个具体的数值。这在一定程度上避免了过拟合,也缓解了错误标签带来的影响。
代码实现
基于上一节的交叉熵实现增加标签平滑功能,代码如下:
class CELoss(nn.Module):
''' Cross Entropy Loss with label smoothing '''
def __init__(self, label_smooth=None, class_num=137):
super().__init__()
self.label_smooth = label_smooth
self.class_num = class_num
def forward(self, pred, target):
''' Args: pred: prediction of model output [N, M] target: ground truth of sampler [N] '''
eps = 1e-12
if self.label_smooth is not None:
# cross entropy loss with label smoothing
logprobs = F.log_softmax(pred, dim=1) # softmax + log
target = F.one_hot(target, self.class_num) # 转换成one-hot
# label smoothing
# 实现 1
# target = (1.0-self.label_smooth)*target + self.label_smooth/self.class_num
# 实现 2
# implement 2
target = torch.clamp(target.float(), min=self.label_smooth/(self.class_num-1), max=1.0-self.label_smooth)
loss = -1*torch.sum(target*logprobs, 1)
else:
# standard cross entropy loss
loss = -1.*pred.gather(1, target.unsqueeze(-1)) + torch.log(torch.exp(pred+eps).sum(dim=1))
return loss.mean()
实现1采用了
(1.0-self.label_smooth)*target + self.label_smooth/self.class_num实现,与原始公式不太一样。后续在了解到pytorch的clamp接口后,发现能够利用其能正确实现原公式,见实现2。
3. 实验验证
① 交叉熵损失正确率,与标准的交叉熵比较:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=None, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.0018)
② 标签平滑结果展示:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=0.05, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.2352)
另一组结果:
x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0007) tensor(0.2641)
分析:拉大模型输出数值间的差距后,原始的交叉熵会变小,而增加了标签平滑的反而变大。这也反映了标签平滑后,并不是概率越接近于1越好,而是接近某个小于1的值,这使得模型的输出不再是越高(+∞)越好。
边栏推荐
- Learn from Li Mu, deep learning - linear regression and basic optimization function
- 2022年中总结
- 2022年最火的十大测试工具,你掌握了几个
- 优麒麟系统安装BeyondComare
- 3、 Uni app fixed or direct to a certain page
- 图书管理数据库系统设计
- cv5200无线WiFi通信模块,视频图像传输无线化,实时无线通信技术
- An intern's journey to cnosdb
- Minio distributed file system learning notes
- What if the content of software testing is too simple?
猜你喜欢

三类6种地图可视化软件测评,最好用的工具居然是它

R language text mining and natural language processing tutorial

Swiftui component how to implement textfield of hidden part of phone number mask (tutorial includes source code)

Getting started with QT & OpenGL
![[R language - basic drawing]](/img/1e/aebf1cbe02c4574671bac6dc2c9171.png)
[R language - basic drawing]

FTM module of K60: configure motor, encoder and steering gear

From Bayesian filter to Kalman filter (I)

SRS4.0安装步骤

2022年暑假ACM热身练习3(详细)

一家芯片公司倒在了B轮
随机推荐
Can I get employed after two months of software testing training?
Module 8 of the construction camp
How to write a JMeter script common to the test team
QT running image
顺序线性表——课上练
服务器正文21:不同编译器对预编译的处理(简单介绍msvc和gcc)
2022年最火的十大测试工具,你掌握了几个
UWB module realizes personnel precise positioning, ultra wideband pulse technology scheme, and real-time centimeter level positioning application
cv5200无线WiFi通信模块,视频图像传输无线化,实时无线通信技术
Getting started with QT & OpenGL
机器学习 --- 模型评估、选择与验证
How to adjust the brightness of win11? Four methods of adjusting screen brightness in win11
Wechat official account custom sharing and updateappmessagesharedata are all successful. Why is it a link that is shared?
CVPR19 - 调参干货《Bag of Tricks for Image Classification with Convolutional Neural Network》
Self-adaptive multi-objective evolutionary algorithm for flexible job shop scheduling with fuzzy pro
vim学习手册
2022年中总结
Lookup - lookup of sequential table and ordered table
Cv5200 wireless WiFi communication module, wireless video image transmission, real-time wireless communication technology
Server body 21: pre compilation processing by different compilers (a brief introduction to MSVC and GCC)