A Plain-Language Explanation of PyTorch's nn.CrossEntropyLoss
2022-07-05 09:06:00 【aelum】
About the author: a non-CS major who switched to coding, constantly expanding my technology stack
Blog home page: https://raelum.blog.csdn.net
Main areas: NLP, RS, GNN
If this article helps you, feel free to follow, like, bookmark, and leave a comment — that is the biggest motivation for my writing
1. Preface
nn.CrossEntropyLoss is often used as the loss function for multi-class classification problems (readers unfamiliar with cross entropy can read my earlier article on it). This post walks through the important points in the PyTorch official documentation one by one (not everything is covered).
import torch
import torch.nn as nn
2. Theoretical Basis
For a $C$-class classification problem ($C>2$), first set aside the batch dimension. Denote the network's output (before Softmax) by $\{x_c\}_{c=1}^C$; after Softmax we obtain

$$q_i=\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)}$$
so the cross-entropy loss of this sample is

$$H(p,q)=-\sum_{i=1}^C p_i\log q_i=-\sum_{i=1}^C p_i\log\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)}$$

where $(p_1,p_2,\cdots,p_C)$ is a one-hot vector.
Without loss of generality, let $p_y=1$ for some $y\in\{1,2,\cdots,C\}$ and all other components be $0$; the formula above then becomes

$$H(p,q)=-\log\frac{\exp(x_y)}{\sum_{c=1}^C\exp(x_c)}$$
Now consider the batch case. Let the batch size be $N$ and the network's output be $\{x_{nc}\},\;n=1,\cdots,N,\;c=1,\cdots,C$. Denote the true class of the $n$-th sample by $y_n\,(y_n\in\{1,2,\cdots,C\})$ and its cross-entropy loss by $l_n$; following the formula above,

$$l_n=-\log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C\exp(x_{nc})}$$
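To make this concrete, here is a minimal sketch (the logits and label are made up) checking that the formula for $l_n$ agrees with nn.CrossEntropyLoss:

logits = torch.tensor([[1.0, 2.0, 0.5]])  # one sample, C = 3, raw scores before Softmax
y = torch.tensor([1])                     # true class index
manual = -torch.log(torch.exp(logits[0, 1]) / torch.exp(logits[0]).sum())
official = nn.CrossEntropyLoss()(logits, y)
print(torch.isclose(manual, official))    # tensor(True)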
Next we discuss some special situations. When the data are imbalanced (one class has far more samples while another has far fewer), we need to assign each class's loss a weight to restore balance. Denote the weights by $\boldsymbol{w}=(w_1,w_2,\cdots,w_C)$.

The model easily overfits the class (or classes) with the most samples, so classes with few samples should receive a higher weight; then, whenever the model mislabels samples of those classes, it is penalized more heavily.
With the weights in place, the corresponding loss is

$$l_n=-w_{y_n}\log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C\exp(x_{nc})}$$
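As a small illustration (the weights and logits below are arbitrary), passing weight simply scales each sample's loss by the weight of its true class:

w = torch.tensor([1.0, 2.0, 0.5])         # one weight per class, chosen arbitrarily
logits = torch.tensor([[1.0, 2.0, 0.5], [0.3, 0.1, 2.2]])
y = torch.tensor([1, 2])
unweighted = nn.CrossEntropyLoss(reduction='none')(logits, y)
weighted = nn.CrossEntropyLoss(weight=w, reduction='none')(logits, y)
print(torch.allclose(weighted, w[y] * unweighted))  # True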
After computing $l_1,l_2,\cdots,l_N$, we can return all of them at once (reduction='none'), return their mean (reduction='mean'), or return their sum (reduction='sum'):

$$\ell=\begin{cases} (l_1,\cdots,l_N),&\text{reduction='none'} \\ \sum_{n=1}^N l_n/\sum_{n=1}^N w_{y_n},&\text{reduction='mean'} \\ \sum_{n=1}^N l_n,&\text{reduction='sum'} \end{cases}$$
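One subtlety worth checking: with weights, reduction='mean' divides by the sum of the true-class weights rather than by $N$. A sketch, reusing the tensors from the snippet above:

mean_loss = nn.CrossEntropyLoss(weight=w, reduction='mean')(logits, y)
sum_loss = nn.CrossEntropyLoss(weight=w, reduction='sum')(logits, y)
print(torch.isclose(mean_loss, sum_loss / w[y].sum()))  # tensor(True)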
In NLP tasks we often append padding elements to the end of each sequence so that sequences of different lengths can be batched together. During training we do not want the network's predictions at padding positions to enter the loss. Let the index of the padding element in the vocabulary be $i$; then $l_n$ is modified as follows:

$$l_n=-w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\log \frac{\exp(x_{n,y_n})}{\sum_{c=1}^C\exp(x_{nc})},\qquad \text{where}\; \mathbb{I}(x)= \begin{cases} 1,&x\;\text{is True} \\ 0,&x\;\text{is False} \end{cases}$$
In addition, in this scenario the loss for reduction='mean' becomes

$$\ell=\frac{\sum_{n=1}^N l_n}{\sum_{n=1}^N w_{y_n}\cdot \mathbb{I}(y_n\neq i)}$$
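A sketch of this padding scenario (assuming, purely for illustration, that the padding token has index 0 in the vocabulary):

pad_idx = 0                                # assumed padding index, for illustration
logits = torch.randn(4, 5)                 # 4 positions, 5 vocabulary entries
y = torch.tensor([2, 0, 3, 0])             # positions 1 and 3 are padding
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_idx, reduction='none')
print(loss_fn(logits, y))                  # the losses at padding positions are 0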
Note that in PyTorch $y_n\in\{0,1,\cdots,C-1\}$; we used $\{1,2,\cdots,C\}$ above only to keep the notation flowing naturally.
3. Main Parameters
The main parameters of nn.CrossEntropyLoss are as follows:
nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
Note: the size_average and reduce parameters are deprecated and have been replaced by reduction, so they are not covered here.
With the groundwork laid above, these parameters are easy to understand:
- weight: a tensor of length $C$; generally used when the data are imbalanced;
- ignore_index: the index of the class to ignore; defaults to $-100$, i.e., nothing is ignored;
- reduction: decides how the loss is returned. With none the $N$ per-sample losses are returned, with mean their average, and with sum their sum. Defaults to mean;
- label_smoothing: decides whether to enable label smoothing (readers unfamiliar with it can refer to my earlier article); its value lies in $[0,1]$. Defaults to $0$, i.e., disabled — see the sketch after this list.
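A minimal sketch of label_smoothing, assuming (as the PyTorch docs describe) that the one-hot target is mixed with a uniform distribution over the $C$ classes:

eps, C = 0.1, 5
logits = torch.randn(3, C)
y = torch.randint(C, size=(3,))
smoothed = nn.CrossEntropyLoss(label_smoothing=eps)(logits, y)
one_hot = torch.eye(C)[y]
p = (1 - eps) * one_hot + eps / C          # smoothed target distribution
manual = -(p * torch.log_softmax(logits, dim=1)).sum(dim=1).mean()
print(torch.isclose(smoothed, manual))     # tensor(True)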
3.1 Inputs and Outputs
The inputs are input and target. input usually has shape $(N, C)$ (i.e., batch_size × num_classes), and target usually has shape $(N,)$, with each component in $[0, C-1] \cap \mathbb{Z}$ denoting the class the sample belongs to.
input and target can also take other forms, but this post only discusses the most common one. input is the raw output of the neural network (not passed through Softmax); nn.CrossEntropyLoss applies Softmax automatically (internally it is equivalent to LogSoftmax followed by NLLLoss — see the check after the code below).
torch.manual_seed(0)

batch_size = 3
num_classes = 5

criterion_1 = nn.CrossEntropyLoss(reduction='none')
criterion_2 = nn.CrossEntropyLoss()
criterion_3 = nn.CrossEntropyLoss(reduction='sum')

inputs = torch.randn(batch_size, num_classes)  # named inputs to avoid shadowing the built-in input (not that it matters much)
target = torch.randint(num_classes, size=(batch_size, ))

print(criterion_1(inputs, target))  # the losses of the 3 samples
# tensor([1.4639, 3.0493, 2.3056])
print(criterion_2(inputs, target))  # the mean of the 3 losses
# tensor(2.2729)
print(criterion_3(inputs, target))  # the sum of the 3 losses
# tensor(6.8188)
print(sum(criterion_1(inputs, target)) == criterion_3(inputs, target))
# tensor(True)
print(sum(criterion_1(inputs, target)) / batch_size == criterion_2(inputs, target))
# tensor(True)
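As a sanity check on the claim that nn.CrossEntropyLoss takes raw logits, it is equivalent to LogSoftmax followed by NLLLoss (reusing inputs and target from the snippet above):

log_probs = torch.log_softmax(inputs, dim=1)
print(torch.isclose(nn.NLLLoss()(log_probs, target), criterion_2(inputs, target)))
# tensor(True)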
4. Implementing nn.CrossEntropyLoss from Scratch

To deepen our understanding, let us now implement nn.CrossEntropyLoss from scratch (it will of course differ from the official version; for readability it is written in a deliberately naive way).

First, set up the skeleton (for simplicity we do not consider label_smoothing):
class CrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, inputs, target):
        pass
For ease of computation, we rewrite the loss formula from Section 2:

$$l_n=w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\Big[-x_{n,y_n}+\log\sum_{c=1}^C\exp(x_{nc})\Big]$$
Rewriting it once more in a Python-style indexing notation:

$$l_n=\boldsymbol{w}[y_n]\cdot \mathbb{I}(y_n\neq i)\cdot\Big[-\boldsymbol{x}_n[y_n]+\log\sum_{c=1}^C\exp(\boldsymbol{x}_n[c])\Big]$$
where $\boldsymbol{w}=(w_1,\cdots,w_C)$ and $\boldsymbol{x}_n=(x_{n1},\cdots,x_{nC})$. Further let $\mathbf{X}=(\boldsymbol{x}_1;\cdots;\boldsymbol{x}_N)$ and $\boldsymbol{y}=(y_1,\cdots,y_N)$. Clearly $\mathbf{X}$ is our input and $\boldsymbol{y}$ is target, so the whole batch can be computed at once:

$$(l_1,\cdots,l_N)=\boldsymbol{w}[\boldsymbol{y}] *\mathbb{I}(\boldsymbol{y}\neq i)* \big(-\mathbf{X}[\text{range}(\text{len}(\boldsymbol{y})),\,\boldsymbol{y}]+\log(\text{sum}(\exp(\mathbf{X}),\,\text{dim}=1))\big)$$

where $*$ denotes element-wise multiplication (all three factors have shape $(N,)$).
class CrossEntropyLoss(nn.Module):
    def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, inputs, target):
        if self.weight is not None:
            n_samples_weight = self.weight[target]  # the weight of each sample
        else:
            n_samples_weight = torch.ones_like(target).float()  # if no weight is given, default to all ones
        # note: unlike the official version, the fancy indexing here requires every
        # entry of target to be a valid class index, so ignore_index must itself be
        # a legal index (a target actually containing -100 would raise an IndexError)
        indicator = (target != self.ignore_index).float()  # turn the Boolean mask into a 0-1 tensor
        raw_loss = -inputs[torch.arange(len(target)), target] + torch.log(torch.sum(torch.exp(inputs), dim=1))
        result = n_samples_weight * indicator * raw_loss
        if self.reduction == 'mean':
            return torch.sum(result) / n_samples_weight.dot(indicator)
        elif self.reduction == 'sum':
            return torch.sum(result)
        else:
            return result
Its output is exactly consistent with the official nn.CrossEntropyLoss; readers can verify this themselves.
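A quick check (with made-up data; as noted in the code comments, our naive version needs ignore_index to be a valid class index, so 0 is used here):

torch.manual_seed(42)
logits = torch.randn(4, 5)
y = torch.tensor([1, 0, 2, 0])             # suppose index 0 marks padding
w = torch.rand(5)
ours = CrossEntropyLoss(weight=w, ignore_index=0, reduction='mean')(logits, y)
official = nn.CrossEntropyLoss(weight=w, ignore_index=0, reduction='mean')(logits, y)
print(torch.isclose(ours, official))       # tensor(True)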