Deep understanding of cross entropy loss function
2022-07-08 02:17:00 【Strawberry sauce toast】
Preface
This article is based on the torch.nn.CrossEntropyLoss() documentation [1] and builds a deeper understanding of the cross entropy loss from both its principle and its implementation details.
1. Cross entropy
1.1 The definition of cross entropy
Suppose $X$ is a discrete random variable, and $p(x)$ and $q(x)$ are two probability distributions of $X$. The cross entropy is defined as:
$$H(q,p)=-\sum_x q(x)\log p(x)$$
Cross entropy measures the similarity between two distributions: the smaller the cross entropy, the more similar $p$ and $q$ are, and $H(q,p)$ attains its minimum when $p=q$.
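As a quick numerical check (not part of the original text), the sketch below evaluates $H(q,p)$ for a fixed $q$ and several candidate distributions $p$; the smallest value is obtained when $p=q$:

import torch

def cross_entropy(q, p):
    # H(q, p) = -sum_x q(x) * log p(x) for two discrete distributions
    return -torch.sum(q * torch.log(p))

q = torch.tensor([0.7, 0.2, 0.1])
for p in (torch.tensor([0.7, 0.2, 0.1]),   # p == q
          torch.tensor([0.5, 0.3, 0.2]),
          torch.tensor([0.1, 0.2, 0.7])):
    print(p.tolist(), cross_entropy(q, p).item())
# the first case (p == q) gives the smallest value, equal to the entropy of q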
1.2 Cross entropy loss
In classification problems, the cross entropy loss (Cross Entropy Loss) is defined as:
$$l(y,\hat y)=-\sum_{j=1}^q y_j\log \hat y_j$$
where $y$ is the sample's class label (a one-hot vector of length $q$), and $\hat y$ is the probability vector predicted by the model.
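Since $y$ is one-hot, only the term of the true class survives in the sum, so the loss reduces to the negative log of the predicted probability of the true class. A minimal sketch with made-up numbers:

import torch

y = torch.tensor([0., 1., 0.])           # one-hot label, true class = 1
y_hat = torch.tensor([0.2, 0.7, 0.1])    # predicted probabilities (illustrative)
loss = -torch.sum(y * torch.log(y_hat))
print(loss.item())                        # -log(0.7) ≈ 0.3567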
2. Maximum likelihood estimation
Likelihood: given the form of the population's distribution function, estimate the parameters of that distribution from the observed events. [2]
Probability: with the population's distribution function fully known, predict how likely the next event is to occur.
2.1 Likelihood function
Suppose the population $X$ is discrete and the form of its distribution law $P\{X=x\}=p(x;\theta),\ \theta\in\Theta$ is known, where $\theta$ is the parameter to be estimated and $\Theta$ is the range of possible values of $\theta$.
Suppose $x_1, x_2, \ldots, x_n$ are the sample values corresponding to the sample $X_1, X_2, \ldots, X_n$. The probability that the event $X_1=x_1, X_2=x_2, \ldots, X_n=x_n$ occurs is:
$$L(\theta)=L(x_1, x_2, \ldots, x_n;\theta)=\prod_{i=1}^n p(x_i;\theta)$$
where $p(x_i;\theta)$ is the probability that the event $X_i=x_i$ occurs. $L(\theta)$, which varies with the value of $\theta$, is called the likelihood function of the sample.
Put plainly: the likelihood function $L(\theta)$ is the probability that the event $\{X_1=x_1, X_2=x_2, \ldots, X_n=x_n\}$ occurs.
2.2 Maximum likelihood estimation
The basic idea: with the sample observations $x_1, x_2, \ldots, x_n$ fixed, choose within the range of possible values of $\theta$ the estimate $\hat\theta$ that maximizes the likelihood function, i.e.:
$$L(x_1,x_2,\ldots,x_n;\hat\theta)=\max_{\theta\in\Theta} L(x_1,x_2,\ldots,x_n;\theta)$$
The resulting parameter estimate $\hat\theta$ depends on the sample values $x_1,x_2,\ldots,x_n$; it is written as $\hat\theta(x_1, x_2,\ldots,x_n)$ and called the maximum likelihood estimate of the parameter $\theta$.
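To make this concrete, here is an illustrative sketch (not from the original text) that recovers the maximum likelihood estimate of a Bernoulli parameter by a simple grid search over the log likelihood; the closed-form MLE is the sample mean:

import torch

x = torch.tensor([1., 0., 1., 1., 0., 1., 1., 0., 1., 0.])  # observed sample values
thetas = torch.linspace(0.01, 0.99, 99)                      # candidate parameter values
# log L(theta) = sum_i [ x_i * log(theta) + (1 - x_i) * log(1 - theta) ]
log_L = torch.stack([(x * torch.log(t) + (1 - x) * torch.log(1 - t)).sum()
                     for t in thetas])
theta_hat = thetas[log_L.argmax()]
print(theta_hat.item())   # ≈ 0.6 == x.mean(), matching the closed-form MLE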
2.3 Maximum likelihood estimation in classification problems
Consider a $K$-class classification problem with $n$ known samples $(x^{(i)}, y^{(i)})$, where the model's predicted class probabilities $p$ are the parameters to be estimated. The likelihood function is:
$$L\big((x^{(i)},y^{(i)});p\big)=\prod_{i=1}^n \prod_{k=1}^K p_k^{\,y_k^{(i)}}$$
Optimization problems are usually posed as minimization rather than maximization, so maximizing the likelihood function can be converted into minimizing the negative log likelihood. The negative log likelihood function is:
$$-\log L\big((x^{(i)},y^{(i)});p\big)=-\sum_{i=1}^n\sum_{k=1}^K y_k^{(i)}\log p_k$$
Comparing this with the definition of the cross entropy loss in Section 1.2, minimizing the cross entropy loss and minimizing the negative log likelihood are formally equivalent.
Therefore, the choice of cross entropy as the loss function for classification models can be understood from the perspective of maximizing the sample likelihood.
3. Implementation of the cross entropy loss function
The torch.nn.CrossEntropyLoss() documentation notes:
Note that this case is equivalent to the combination of LogSoftmax and NLLLoss.
The implementation of log_softmax in PyTorch has already been covered in detail in a separate blog post. [3]
Here, the focus is on understanding the connection between the negative log likelihood and the cross entropy through code (see reference [4]).
3.1 Negative log likelihood loss function (Negative Log Likelihood Loss)
The negative log likelihood loss is defined as:
$$\mathrm{nllloss}=-\frac{1}{N}\sum_{i=1}^N y_i\log \hat y=-\frac{1}{N}\sum_{i=1}^N y_i\,(\mathrm{log\_softmax})$$
where $N$ is the number of samples, $y_i$ is the one-hot encoded ground-truth label, and $\hat y$ is the probability vector output by the model.
>>> import torch
>>> import torch.nn.functional as F
>>> import torch.nn as nn
>>> X = torch.randn(5, 5)  # create a 5x5 input (5 samples, 5-class problem)
>>> label = torch.tensor([0, 2, 3, 4, 1])  # ground-truth labels of the 5 samples
>>> label_one_hot = F.one_hot(label).float()
>>> P = F.log_softmax(X, dim=1)  # convert the outputs to log-probabilities (with overflow protection)
''' Compute the NLL loss by hand '''
>>> nllloss = -torch.sum(label_one_hot * P) / label.shape[0]
>>> nllloss
tensor(1.9052)
''' Compute the negative log likelihood loss via the PyTorch API '''
>>> nllloss_1 = F.nll_loss(P, label)  # no one-hot encoding needed here
>>> nllloss_1
tensor(1.9052)
''' Compute the cross entropy loss via the PyTorch API '''
>>> cross_entropy_loss = F.cross_entropy(X, label)
>>> cross_entropy_loss
tensor(1.9052)
All three computations give the same result.
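One point worth making explicit (a side note restating what the session above already shows, under the default reduction='mean'): F.nll_loss expects log-probabilities, whereas F.cross_entropy applies log_softmax internally and therefore takes the raw logits directly.

import torch
import torch.nn.functional as F

X = torch.randn(5, 5)
label = torch.tensor([0, 2, 3, 4, 1])
loss_a = F.nll_loss(F.log_softmax(X, dim=1), label)  # nll_loss takes log-probabilities
loss_b = F.cross_entropy(X, label)                   # cross_entropy takes raw logits
print(torch.allclose(loss_a, loss_b))                # True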
3.2 Implementing the CrossEntropyLoss function by hand
Finally, the self-implemented code:
''' Custom log-softmax: normalize the model output while preventing overflow '''
def log_softmax(X):
    c, _ = torch.max(X, dim=1, keepdim=True)
    log_sum_exp = c + torch.log(torch.sum(torch.exp(X - c), dim=1, keepdim=True))
    return X - log_sum_exp

''' Custom negative log likelihood function '''
def nll_loss(P_k, label):
    label_one_hot = F.one_hot(label)
    return -torch.sum(label_one_hot * P_k) / label.shape[0]  # average over all samples (cross_entropy defaults to reduction='mean')
>>> X = torch.randn(5, 5)
>>> label = torch.tensor([0,2,3,4,1])
>>> nll_loss(log_softmax(X), label) == F.cross_entropy(X, label)
tensor(True)
The output is True, showing that the implementation of torch.nn.CrossEntropyLoss() is consistent with the custom implementation.
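For completeness, a minimal sketch of the module form referenced in the preface; nn.CrossEntropyLoss() with the default reduction='mean' gives the same value as F.cross_entropy:

import torch
import torch.nn as nn
import torch.nn.functional as F

criterion = nn.CrossEntropyLoss()        # module form, default reduction='mean'
X = torch.randn(5, 5)                    # raw logits
label = torch.tensor([0, 2, 3, 4, 1])
print(torch.allclose(criterion(X, label), F.cross_entropy(X, label)))  # True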