PyTorch: implementing CrossEntropyLoss and Label Smoothing
2022-07-28 19:26:00 【I'm Mr. rhubarb】
0. Preface
Normally we simply call PyTorch's built-in cross entropy loss to compute the loss. But when it comes to customizing or optimizing a model, we often need to write the loss function ourselves, and a solid understanding of how cross entropy loss is implemented helps us write cleaner code.
Second, the label smoothing trick is simple and usually effective: changing nothing but the loss function can already improve performance, and it is typically used together with cross entropy.
Starting from these two points, this article introduces how to implement cross entropy loss and label smoothing under the PyTorch framework.
1. A brief look at CrossEntropyLoss
I believe most readers are familiar with how cross entropy is computed. The usual procedure is: ① apply softmax to obtain the confidence of each class; ② compute the cross entropy loss. In fact, according to the official PyTorch documentation, CrossEntropyLoss is computed in a more direct way:

$$\mathrm{loss}(x, class) = -x[class] + \log\Bigl(\sum_{j} \exp(x[j])\Bigr)$$

This avoids computing the softmax explicitly.
Code implementation
It is straightforward: simply translate the formula above into code.
import torch
import torch.nn as nn


class CELoss(nn.Module):
    '''Cross Entropy Loss'''
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        '''
        Args:
            pred: model output (logits), shape [N, M]
            target: ground-truth class indices, shape [N]
        '''
        eps = 1e-12
        # standard cross entropy: -x[class] + log(sum_j exp(x[j]))
        loss = -1. * pred.gather(1, target.unsqueeze(-1)).squeeze(-1) \
               + torch.log(torch.exp(pred + eps).sum(dim=1))
        return loss.mean()
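One caveat worth noting (my own remark, not from the original post): torch.exp(pred) can overflow for very large logits, and the eps added inside the exponent does not guard against that. Below is a minimal sketch of a more numerically stable variant of the same formula, using torch.logsumexp, which subtracts the row maximum internally:

import torch
import torch.nn as nn


class StableCELoss(nn.Module):
    '''Cross entropy as -x[class] + logsumexp(x), computed stably.'''
    def forward(self, pred, target):
        # logsumexp avoids overflow by shifting by the row maximum internally
        loss = -pred.gather(1, target.unsqueeze(-1)).squeeze(-1) \
               + torch.logsumexp(pred, dim=1)
        return loss.mean()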
2. A brief look at Label Smoothing
Label Smoothing is essentially a regularization method for preventing overfitting. The traditional classification loss is the softmax loss: apply softmax to the output of the fully connected layer, treat the result as the confidence probability of each class, and then compute the cross entropy:

$$p_i = \frac{\exp(z_i)}{\sum_{j=1}^{K} \exp(z_j)}, \qquad L = -\sum_{i=1}^{K} y_i \log p_i$$

During training this pushes the output probability of each sample on its correct class towards 1, which drives the corresponding logit z towards +∞ and widens its gap from the other classes.
Now suppose the label of a sample in a multi-class task is [1, 0, 0]. If the label itself is wrong, this is very harmful to the model: training forces it to learn that the sample belongs to a class it does not, and to assign it a very high probability, which distorts the estimate of the posterior probability. Moreover, classes are not always unrelated to each other; encouraging too large a gap between the output probabilities can lead to a certain degree of overfitting.
The idea of Label Smoothing is therefore to replace the one-hot label with the following target:

$$y_i' = \begin{cases} 1-\varepsilon, & i = target \\ \dfrac{\varepsilon}{K-1}, & i \neq target \end{cases}$$

where ε is a small constant and K is the number of classes. With this target, the optimal probabilities under the softmax loss are no longer exactly 1 and 0, and the optimal logit z is no longer positive infinity but a finite value. This avoids overfitting to some extent and also alleviates the impact of incorrectly labeled samples.
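As a concrete illustration (numbers chosen for this example, not from the original post): with K = 3 classes and ε = 0.05, the one-hot label [1, 0, 0] becomes

$$[1,\,0,\,0] \;\rightarrow\; \Bigl[\,1-0.05,\; \tfrac{0.05}{2},\; \tfrac{0.05}{2}\,\Bigr] = [0.95,\; 0.025,\; 0.025]$$

which still sums to 1 but no longer demands an infinitely confident prediction.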
Code implementation
Starting from the cross entropy loss of the previous section, a label smoothing option is added. The code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CELoss(nn.Module):
    '''Cross Entropy Loss with label smoothing'''
    def __init__(self, label_smooth=None, class_num=137):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num

    def forward(self, pred, target):
        '''
        Args:
            pred: model output (logits), shape [N, M]
            target: ground-truth class indices, shape [N]
        '''
        eps = 1e-12
        if self.label_smooth is not None:
            # cross entropy loss with label smoothing
            logprobs = F.log_softmax(pred, dim=1)        # softmax + log
            target = F.one_hot(target, self.class_num)   # convert indices to one-hot
            # label smoothing
            # Implementation 1:
            # target = (1.0 - self.label_smooth) * target + self.label_smooth / self.class_num
            # Implementation 2:
            target = torch.clamp(target.float(),
                                 min=self.label_smooth / (self.class_num - 1),
                                 max=1.0 - self.label_smooth)
            loss = -1 * torch.sum(target * logprobs, 1)
        else:
            # standard cross entropy loss
            loss = -1. * pred.gather(1, target.unsqueeze(-1)).squeeze(-1) \
                   + torch.log(torch.exp(pred + eps).sum(dim=1))
        return loss.mean()
Implementation 1 uses (1.0-self.label_smooth)*target + self.label_smooth/self.class_num, which differs from the original formula: it spreads ε/K uniformly over all K classes (including the correct one) instead of putting ε/(K-1) on each of the K-1 wrong classes. After learning about PyTorch's clamp interface, I found that it can be used to implement the original formula exactly; see Implementation 2.
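A quick sketch (my own illustration, with ε = 0.05 and K = 3 chosen arbitrarily) makes the difference between the two implementations visible:

import torch
import torch.nn.functional as F

eps_s, K = 0.05, 3
one_hot = F.one_hot(torch.tensor([0]), K).float()             # [[1., 0., 0.]]

# Implementation 1: spread eps/K over all K classes
smooth1 = (1.0 - eps_s) * one_hot + eps_s / K                  # [[0.9667, 0.0167, 0.0167]]

# Implementation 2: 1-eps for the target, eps/(K-1) for the rest
smooth2 = torch.clamp(one_hot, min=eps_s / (K - 1), max=1.0 - eps_s)  # [[0.9500, 0.0250, 0.0250]]

print(smooth1, smooth2)

Both targets still sum to 1, but only Implementation 2 matches the formula given in the previous section.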
3. Experimental verification
① Correctness of the cross entropy loss, compared against the standard nn.CrossEntropyLoss:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=None, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.0018)
② Results with label smoothing:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=0.05, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.2352)
Another set of results:
x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0007) tensor(0.2641)
Analysis: after widening the gap between the model's output values, the standard cross entropy decreases, while the smoothed loss becomes larger. This reflects the fact that, with smoothed labels, a probability closer to 1 is no longer always better; the optimum is a value slightly below 1, so pushing the model's output ever higher (towards +∞) no longer helps.
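As a side note (my own addition, not part of the original post): newer PyTorch versions (1.10+) expose a label_smoothing argument directly on nn.CrossEntropyLoss. A minimal sketch, assuming such a version is installed; note that the built-in option follows the ε/K convention (Implementation 1 above), so its value will differ slightly from the CELoss result with Implementation 2:

import torch
import torch.nn as nn

loss_builtin = nn.CrossEntropyLoss(label_smoothing=0.05)   # requires PyTorch >= 1.10
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss_builtin(x, y))   # close to, but not identical to, CELoss(label_smooth=0.05)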