PyTorch: Implementation of CrossEntropyLoss and Label Smoothing
2022-07-28 19:26:00 【I'm Mr. rhubarb】
0. Preface
Usually we simply call PyTorch's built-in cross entropy loss function to compute the loss. But when it comes to modifying and optimizing it, we need to write the loss function ourselves, and a solid understanding of how cross entropy loss is implemented in code helps us write cleaner code.
Second, the label smoothing trick is simple and usually effective: changing only the loss function can already improve performance, and it is typically used together with cross entropy.
Therefore, starting from these two points, this article introduces the implementation of cross entropy loss and label smoothing under the PyTorch framework.
1. A Brief Look at CrossEntropyLoss
I believe everyone is familiar with how to compute cross entropy. The usual procedure is: ① compute the softmax to obtain the confidence of each class; ② compute the cross entropy loss. However, according to the official PyTorch documentation, it is actually implemented differently, as follows:
loss(x, class) = -log( exp(x[class]) / Σ_j exp(x[j]) ) = -x[class] + log( Σ_j exp(x[j]) )
This avoids explicitly computing the softmax.
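As a quick sanity check (my own sketch, not part of the original article), the identity above can be verified against F.cross_entropy; torch.logsumexp is used here only for numerical stability:

import torch
import torch.nn.functional as F

x = torch.randn(4, 10)              # [N, M] logits
y = torch.randint(0, 10, (4,))      # [N] class indices

# -x[class] + log(sum_j exp(x[j])), written with logsumexp for stability
manual = (-x.gather(1, y.unsqueeze(-1)).squeeze(-1) + torch.logsumexp(x, dim=1)).mean()
print(torch.allclose(manual, F.cross_entropy(x, y)))  # True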
Code implementation
It is simple: just write the code according to the formula.
import torch
import torch.nn as nn


class CELoss(nn.Module):
    ''' Cross Entropy Loss '''
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        '''
        Args:
            pred: prediction of model output    [N, M]
            target: ground truth of sampler     [N]
        '''
        eps = 1e-12
        # standard cross entropy loss: -x[class] + log(sum_j exp(x[j]))
        loss = -1. * pred.gather(1, target.unsqueeze(-1)).squeeze(-1) \
               + torch.log(torch.exp(pred + eps).sum(dim=1))
        return loss.mean()
2. A Brief Look at Label Smoothing
Label smoothing is essentially a regularization method for preventing overfitting. The traditional classification loss is the softmax loss: first compute the softmax of the fully connected layer's output and treat it as the confidence probability of each class, then compute the loss with cross entropy.
p_i = exp(z_i) / Σ_j exp(z_j),    L = -Σ_i y_i · log(p_i)
In this process, training tries to push the output probability of the correct class toward 1 for every sample, which drives the corresponding z value toward +∞ and widens its gap with the other classes.
Now suppose the label of a sample in a multi-class task is [1, 0, 0]. If the label itself is wrong, this is very harmful to the model: during training it is forced to learn that a sample belongs to a class it does not, and to assign that class a very high probability, which distorts the estimate of the posterior probability. Moreover, the classes are not always unrelated; encouraging overly large gaps between the output probabilities leads to a certain degree of overfitting.
Therefore, the idea of label smoothing is that the target is no longer a one-hot label, but instead takes the following form:
y'_i = 1 - ε   if i is the ground-truth class,    y'_i = ε / (K - 1)   otherwise   (K = number of classes)
where ε is a small constant. With this target, the optimal probabilities under the softmax loss are no longer exactly 1 and 0, and the optimal z values are no longer positive infinity but finite values. This avoids overfitting to some extent and also alleviates the impact of mislabeled samples.
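As a concrete example (the numbers here are my own illustration, not from the original text): with ε = 0.05 and K = 3 classes, the one-hot label [1, 0, 0] becomes [0.95, 0.025, 0.025]:

import torch

eps = 0.05   # smoothing constant epsilon
K = 3        # number of classes
# one-hot target [1, 0, 0] after smoothing
smoothed = torch.full((K,), eps / (K - 1))
smoothed[0] = 1.0 - eps
print(smoothed)   # tensor([0.9500, 0.0250, 0.0250])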
Code implementation
Adding label smoothing on top of the cross entropy from the previous section gives the following code:
import torch
import torch.nn as nn
import torch.nn.functional as F


class CELoss(nn.Module):
    ''' Cross Entropy Loss with label smoothing '''
    def __init__(self, label_smooth=None, class_num=137):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num

    def forward(self, pred, target):
        '''
        Args:
            pred: prediction of model output    [N, M]
            target: ground truth of sampler     [N]
        '''
        eps = 1e-12
        if self.label_smooth is not None:
            # cross entropy loss with label smoothing
            logprobs = F.log_softmax(pred, dim=1)       # softmax + log
            target = F.one_hot(target, self.class_num)  # convert to one-hot
            # label smoothing
            # Realization 1:
            # target = (1.0 - self.label_smooth) * target + self.label_smooth / self.class_num
            # Realization 2:
            target = torch.clamp(target.float(),
                                 min=self.label_smooth / (self.class_num - 1),
                                 max=1.0 - self.label_smooth)
            loss = -1 * torch.sum(target * logprobs, 1)
        else:
            # standard cross entropy loss
            loss = -1. * pred.gather(1, target.unsqueeze(-1)).squeeze(-1) \
                   + torch.log(torch.exp(pred + eps).sum(dim=1))
        return loss.mean()
Realization 1 uses (1.0 - self.label_smooth) * target + self.label_smooth / self.class_num, which differs slightly from the original formula. Later, after learning about PyTorch's clamp interface, I found it can be used to implement the original formula exactly; see Realization 2.
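To make the difference concrete, here is a small sketch (my own illustration; ε = 0.05 and 3 classes are chosen for readability instead of the class_num=137 default) showing what each realization produces for the one-hot target [1, 0, 0]:

import torch

eps, K = 0.05, 3
target = torch.tensor([1.0, 0.0, 0.0])

# Realization 1: (1 - eps) * y + eps / K
r1 = (1.0 - eps) * target + eps / K
# Realization 2: clamp to [eps / (K - 1), 1 - eps], matching the original formula
r2 = torch.clamp(target, min=eps / (K - 1), max=1.0 - eps)

print(r1)  # tensor([0.9667, 0.0167, 0.0167])
print(r2)  # tensor([0.9500, 0.0250, 0.0250])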
3. Experimental verification
① Correctness of the cross entropy loss, compared with the standard cross entropy:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=None, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.0018)
② Results with label smoothing:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=0.05, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.2352)
Another set of results:
x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0007) tensor(0.2641)
Analysis: after the gap between the model's output values is widened, the original cross entropy decreases, while the loss with smoothed labels increases. This shows that with smoothed labels, the output probability should no longer be pushed as close to 1 as possible, but toward a value somewhat below 1, so a larger model output (toward +∞) is no longer always better.
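To illustrate this (a minimal sketch of my own, not an experiment from the original article), sweeping the logit gap for a single 3-class sample with ε = 0.05 shows that the smoothed loss reaches its minimum at a finite gap and then rises again:

import torch
import torch.nn.functional as F

eps = 0.05
target = torch.tensor([1.0 - eps, eps / 2, eps / 2])  # smoothed version of [1, 0, 0]
for gap in [2.0, 4.0, 6.0, 8.0, 16.0]:
    logits = torch.tensor([gap, 0.0, 0.0])
    loss = -(target * F.log_softmax(logits, dim=0)).sum()
    print(f'gap={gap:5.1f}  loss={loss.item():.4f}')
# the loss decreases up to a point (around gap 4 here) and then grows as the gap keeps widening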