PyTorch: implementation of CrossEntropyLoss and label smoothing
2022-07-28 19:26:00 【I'm Mr. rhubarb】
0. Preface
Normally we just call the cross entropy loss function that ships with PyTorch. But once we start modifying and optimizing models, we often need to write the loss function ourselves, and having some understanding of how cross entropy loss is implemented in code helps us write it more cleanly.
Second, the label smoothing trick is simple and effective: changing nothing but the loss function can already improve performance, and it is usually used together with cross entropy.
Starting from these two points, this article introduces how cross entropy loss and label smoothing are implemented under the PyTorch framework.
1. Talking about CrossEntropyLoss
You are probably very familiar with how cross entropy is computed: the usual procedure is ① apply softmax to obtain the confidence of each class, then ② compute the cross entropy loss. In fact, according to the official PyTorch documentation, nn.CrossEntropyLoss is computed in the following equivalent form:

$$\text{loss}(x, \text{class}) = -x[\text{class}] + \log\Big(\sum_{j} e^{x[j]}\Big)$$

This avoids explicitly computing the softmax.
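As a quick sanity check (a minimal sketch with arbitrary example logits, not part of the original post), the log-sum-exp form gives the same value as the usual log-softmax plus negative log-likelihood route:

import torch
import torch.nn.functional as F

x = torch.tensor([[1.0, 8.0, 1.0]])   # logits for one sample with 3 classes
y = torch.tensor([1])                 # ground-truth class index

# -x[class] + log(sum_j exp(x[j]))
lse_form = -x.gather(1, y.unsqueeze(-1)) + torch.logsumexp(x, dim=1, keepdim=True)
# softmax -> log -> negative log-likelihood
nll_form = F.nll_loss(F.log_softmax(x, dim=1), y)

print(lse_form.item(), nll_form.item())   # both print the same value (about 0.0018)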
Code implementation
This is straightforward: just translate the formula above into code.
import torch
import torch.nn as nn

class CELoss(nn.Module):
    ''' Cross Entropy Loss '''
    def __init__(self):
        super().__init__()

    def forward(self, pred, target):
        '''
        Args:
            pred: prediction of model output  [N, M]
            target: ground truth of sampler   [N]
        '''
        eps = 1e-12
        # standard cross entropy loss: -x[class] + log(sum_j exp(x[j]))
        loss = -1. * pred.gather(1, target.unsqueeze(-1)).squeeze(1) \
               + torch.log(torch.exp(pred).sum(dim=1) + eps)
        return loss.mean()
2. Talking about Label Smoothing
Label Smoothing is, as the name says, a way of smoothing the labels; it is essentially a regularization method for preventing overfitting. The traditional classification loss is the softmax loss: first apply softmax to the output z of the fully connected layer and treat the result as the confidence probability of each class, then compute the loss with cross entropy:

$$p_i = \frac{e^{z_i}}{\sum_{j} e^{z_j}}, \qquad \text{loss} = -\sum_{i} y_i \log p_i$$

where y is the one-hot label.
During training this pushes the output probability of the correct class toward 1 for every sample, which drives the corresponding z value toward +∞ and widens its gap to the other classes.
Now suppose the label of a sample in a multi-class task is [1, 0, 0]. If the label itself is wrong, this is very harmful to the model: training forces it to learn a sample that does not actually belong to that class, and to assign it a very high probability, which distorts the estimate of the posterior probability. Moreover, the classes are not always unrelated to one another; encouraging too large a gap between the output probabilities can lead to a certain degree of overfitting.
Therefore, the idea of Label Smoothing is that the target is no longer a one-hot label, but instead takes the following form:

$$y_i = \begin{cases} 1 - \varepsilon, & i = \text{target} \\ \dfrac{\varepsilon}{K - 1}, & i \neq \text{target} \end{cases}$$

where ε is a small constant and K is the number of classes. With this target, the optimal probabilities under the softmax loss are no longer exactly 1 and 0, and the optimal z values are no longer infinite but finite. This avoids overfitting to some extent and also alleviates the impact of wrong labels.
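For intuition, here is a small illustrative sketch (3 classes and ε = 0.05, chosen only for the example) of what a smoothed label looks like:

import torch

eps, class_num = 0.05, 3
target = 1                                       # index of the true class

# build the smoothed label directly from the formula above
smoothed = torch.full((class_num,), eps / (class_num - 1))
smoothed[target] = 1.0 - eps
print(smoothed)   # tensor([0.0250, 0.9500, 0.0250]); the entries still sum to 1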
Code implementation
Adding the label smoothing option on top of the cross entropy from the previous section gives the following code:
import torch
import torch.nn as nn
import torch.nn.functional as F

class CELoss(nn.Module):
    ''' Cross Entropy Loss with label smoothing '''
    def __init__(self, label_smooth=None, class_num=137):
        super().__init__()
        self.label_smooth = label_smooth
        self.class_num = class_num

    def forward(self, pred, target):
        '''
        Args:
            pred: prediction of model output  [N, M]
            target: ground truth of sampler   [N]
        '''
        eps = 1e-12
        if self.label_smooth is not None:
            # cross entropy loss with label smoothing
            logprobs = F.log_softmax(pred, dim=1)       # softmax + log
            target = F.one_hot(target, self.class_num)  # convert to one-hot
            # label smoothing
            # Realization 1:
            # target = (1.0 - self.label_smooth) * target + self.label_smooth / self.class_num
            # Realization 2:
            target = torch.clamp(target.float(),
                                 min=self.label_smooth / (self.class_num - 1),
                                 max=1.0 - self.label_smooth)
            loss = -1 * torch.sum(target * logprobs, 1)
        else:
            # standard cross entropy loss
            loss = -1. * pred.gather(1, target.unsqueeze(-1)).squeeze(1) \
                   + torch.log(torch.exp(pred).sum(dim=1) + eps)
        return loss.mean()
Realization 1 uses (1.0 - self.label_smooth) * target + self.label_smooth / self.class_num, which differs slightly from the original formula. Later, after learning about PyTorch's clamp interface, I found that it can be used to implement the original formula exactly; see Realization 2.
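To make the difference concrete, here is a small illustrative comparison (3 classes and ε = 0.05, chosen only for this sketch) of the smoothed targets produced by the two realizations:

import torch
import torch.nn.functional as F

eps, class_num = 0.05, 3
one_hot = F.one_hot(torch.tensor([1]), class_num).float()

# Realization 1: spreads eps/K over all classes, including the true one
r1 = (1.0 - eps) * one_hot + eps / class_num
# Realization 2: exactly 1 - eps on the true class, eps/(K - 1) on each of the others
r2 = torch.clamp(one_hot, min=eps / (class_num - 1), max=1.0 - eps)

print(r1)   # tensor([[0.0167, 0.9667, 0.0167]])
print(r2)   # tensor([[0.0250, 0.9500, 0.0250]])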
3. Experimental verification
① Correctness of the cross entropy loss, compared against the standard nn.CrossEntropyLoss:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=None, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.0018)
② Results with label smoothing enabled:
loss1 = nn.CrossEntropyLoss()
loss2 = CELoss(label_smooth=0.05, class_num=3)
x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0018) tensor(0.2352)
Another set of results:
x = torch.tensor([[0.1, 8, 0.1], [0.1, 0.1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])
print(loss1(x, y), loss2(x, y))
# tensor(0.0007) tensor(0.2641)
Analysis: after the gap between the model's output values is widened, the plain cross entropy loss decreases, while the label-smoothed loss becomes larger. This reflects what smoothed labels are meant to do: the predicted probability of the correct class should not be pushed as close to 1 as possible, but toward a value slightly below 1, so that ever larger model outputs (toward +∞) are no longer always better.
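As a side note (assuming PyTorch 1.10 or later), nn.CrossEntropyLoss itself accepts a label_smoothing argument; its smoothing follows the (1 - ε) * one-hot + ε/K scheme, i.e. Realization 1 above, so it is a convenient cross-check rather than an exact replica of Realization 2:

import torch
import torch.nn as nn

x = torch.tensor([[1, 8, 1], [1, 1, 8]], dtype=torch.float)
y = torch.tensor([1, 2])

# built-in label smoothing (PyTorch >= 1.10); the smoothing scheme matches Realization 1
loss = nn.CrossEntropyLoss(label_smoothing=0.05)
print(loss(x, y))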