当前位置:网站首页>深入浅出PyTorch中的nn.CrossEntropyLoss
深入浅出PyTorch中的nn.CrossEntropyLoss
2022-07-05 09:01:00 【aelum】
作者简介:非科班转码,正在不断丰富自己的技术栈
️ 博客主页:https://raelum.blog.csdn.net
主要领域:NLP、RS、GNN
如果这篇文章有帮助到你,可以关注️ + 点赞 + 收藏 + 留言,这将是我创作的最大动力

一、前言
nn.CrossEntropyLoss 常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。
import torch
import torch.nn as nn
二、理论基础
对于 C ( C > 2 ) C\,(C>2) C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(还未经过 Softmax)为 { x c } c = 1 C \{x_c\}_{c=1}^C { xc}c=1C,经过 Softmax 后得到
q i = exp ( x i ) ∑ c = 1 C exp ( x c ) q_i=\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} qi=∑c=1Cexp(xc)exp(xi)
从而该样本的交叉熵损失为
H ( p , q ) = − ∑ i = 1 C p i log q i = − ∑ i = 1 C p i log exp ( x i ) ∑ c = 1 C exp ( x c ) H(p,q)=-\sum_{i=1}^C p_i\log q_i=-\sum_{i=1}^C p_i\log\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=−i=1∑Cpilogqi=−i=1∑Cpilog∑c=1Cexp(xc)exp(xi)
其中 ( p 1 , p 2 , ⋯ , p C ) (p_1,p_2,\cdots,p_C) (p1,p2,⋯,pC) 是 One-Hot 向量。
不妨令 p y = 1 ( y ∈ { 1 , 2 , ⋯ , C } ) p_y=1\,(y\in\{1,2,\cdots,C\}) py=1(y∈{ 1,2,⋯,C}),其余为 0 0 0,因此上式变为
H ( p , q ) = − log exp ( x y ) ∑ c = 1 C exp ( x c ) H(p,q)=-\log\frac{\exp(x_y)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=−log∑c=1Cexp(xc)exp(xy)
现在考虑有 batch 的情形,不妨设 batch size 为 N N N,神经网络的输出为 { x n c } n c , n = 1 , ⋯ , N , c = 1 , ⋯ , C \{x_{nc}\}_{nc},\;n=1,\cdots,N,\;c=1,\cdots,C { xnc}nc,n=1,⋯,N,c=1,⋯,C,第 n n n 个样本的真实类别记为 y n ( y n ∈ { 1 , 2 , ⋯ , C } ) y_n\,(y_n\in\{1,2,\cdots,C\}) yn(yn∈{ 1,2,⋯,C}),第 n n n 个样本的交叉熵损失记为 l n l_n ln,则仿照上式就有
l n = − log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=−log∑c=1Cexp(xnc)exp(xn,yn)
接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为 w = ( w 1 , w 2 , ⋯ , w C ) \boldsymbol{w}=(w_1,w_2,\cdots,w_C) w=(w1,w2,⋯,wC)。
模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚
安排了权重后,相应的损失为
l n = − w y n log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-w_{y_n}\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=−wynlog∑c=1Cexp(xnc)exp(xn,yn)
计算完 l 1 , l 2 , ⋯ , l N l_1,l_2,\cdots,l_N l1,l2,⋯,lN 后,我们既可以一次性将它们全部返回(对应 reduction=none),也可以返回它们的均值(对应 reduction=mean),还可以返回它们的和(对应 reduction=sum):
ℓ = { ( l 1 , ⋯ , l N ) , reduction=none ∑ n = 1 N l n / ∑ n = 1 N w y n , reduction=mean ∑ n = 1 N l n , reduction=sum \ell=\begin{cases} (l_1,\cdots,l_N),&\text{reduction=none} \\ \sum_{n=1}^N l_n/\sum_{n=1}^N w_{y_n},&\text{reduction=mean} \\ \sum_{n=1}^N l_n,&\text{reduction=sum} \\ \end{cases} ℓ=⎩⎪⎨⎪⎧(l1,⋯,lN),∑n=1Nln/∑n=1Nwyn,∑n=1Nln,reduction=nonereduction=meanreduction=sum
在 NLP 任务中,我们往往将填充词元添加到每个序列的末尾,这样一来不同长度的序列可以进行批量加载。训练过程中,我们不希望网络预测出的填充词元被算入损失函数中。不妨设填充词元在词表中的索引为 i i i,则此时应对 l n l_n ln 作如下修正:
l n = − w y n ⋅ I ( y n ≠ i ) ⋅ log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) , where I ( x ) = { 1 , x is True 0 , x is False l_n=-w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})},\qquad \text{where}\; \mathbb{I}(x)= \begin{cases} 1,&x\; \text{is True} \\ 0,&x\; \text{is False} \end{cases} ln=−wyn⋅I(yn=i)⋅log∑c=1Cexp(xnc)exp(xn,yn),whereI(x)={ 1,0,xis Truexis False
另外,该场景下的 reduction=mean 对应的损失变为
ℓ = ∑ n = 1 N l n ∑ n = 1 N w y n ⋅ I ( y n ≠ i ) \ell=\sum_{n=1}^N\frac{l_n}{\sum_{n=1}^Nw_{y_n}\cdot \mathbb{I}(y_n\neq i)} ℓ=n=1∑N∑n=1Nwyn⋅I(yn=i)ln
需要注意的是,在PyTorch中 y n ∈ { 0 , 1 , ⋯ , C − 1 } y_n\in\{0,1,\cdots,C-1\} yn∈{ 0,1,⋯,C−1},这里我们之所以用 { 1 , 2 , ⋯ , C } \{1,2,\cdots,C\} { 1,2,⋯,C} 是为了更自然地衔接上下文
三、主要参数
nn.CrossEntropyLoss 的主要参数如下:
nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
️
size_average和reduce参数已经弃用,取而代之的是reduction参数,所以这里不再讲解
有了前面的铺垫,我们就可以很容易理解这些参数了:
weight:长度为 C C C 的张量,一般在数据不平衡时才会使用;ignore_index:需要忽略的类别的索引,默认为 − 100 -100 −100,即不忽略;reduction:决定以何种形式返回损失。为none时返回 N N N 个样本的损失,为mean时返回 N N N 个样本的损失均值,为sum时返回 N N N 个样本的损失的和。默认为mean;label_smoothing:决定是否开启标签平滑(不了解标签平滑的读者可参考这篇文章),数值在 [ 0 , 1 ] [0,1] [0,1] 内。默认为 0 0 0,即不开启。
3.1 输入与输出
输入分为 input 和 target,input 通常为 ( N , C ) (N,C) (N,C) 的形状(即 batch_size × num_classes),target 通常为 ( N , ) (N,) (N,) 的形状,其中的每个分量均位于 [ 0 , C − 1 ] ∩ Z [0,C-1] \cap \mathbb{Z} [0,C−1]∩Z 中,代表样本属于的类别。
input和target还可以是其他类型的输入,但本文只讨论这种使用最为广泛的输入input是神经网络的原始输出(未经过 Softmax),nn.CrossEntropyLoss会自动对其应用 Softmax
torch.manual_seed(0)
batch_size = 3
num_classes = 5
criterion_1 = nn.CrossEntropyLoss(reduction='none')
criterion_2 = nn.CrossEntropyLoss()
criterion_3 = nn.CrossEntropyLoss(reduction='sum')
inputs = torch.randn(batch_size, num_classes) # 避免与input关键字冲突(当然这无所谓)
target = torch.randint(num_classes, size=(batch_size, ))
print(criterion_1(inputs, target)) # 输出3个样本的loss
# tensor([1.4639, 3.0493, 2.3056])
print(criterion_2(inputs, target)) # 输出3个样本的loss的均值
# tensor(2.2729)
print(criterion_3(inputs, target)) # 输出3个样本的loss的和
# tensor(6.8188)
print(sum(criterion_1(inputs, target)) == criterion_3(inputs, target))
# tensor(True)
print(sum(criterion_1(inputs, target)) / batch_size == criterion_2(inputs, target))
# tensor(True)
四、从零开始实现 nn.CrossEntropyLoss
为了加深理解,接下来我们从零开始实现 nn.CrossEntropyLoss(当然会和官方不同,为了追求可读性会采用傻瓜式实现)。
首先确定框架(为简便起见这里不考虑 label_smoothing):
class CrossEntropyLoss(nn.Module):
def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super().__init__()
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, inputs, target):
pass
为方便计算,我们对第二章节的损失计算公式进行改写
l n = w y n ⋅ I ( y n ≠ i ) ⋅ [ − x n , y n + log ∑ c = 1 C exp ( x n c ) ] l_n=w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot[-x_{n,y_n}+\log\sum_{c=1}^C\exp(x_{nc})] ln=wyn⋅I(yn=i)⋅[−xn,yn+logc=1∑Cexp(xnc)]
采用更符合 Python 的表述方式来改写上式
l n = w [ y n ] ⋅ I ( y n ≠ i ) ⋅ [ − x n [ y n ] + log ∑ c = 1 C exp ( x n [ c ] ) ] l_n=\boldsymbol{w}[y_n]\cdot \mathbb{I}(y_n\neq i)\cdot[-\boldsymbol{x_n}[y_n]+\log\sum_{c=1}^C\exp(\boldsymbol{x_n}[c])] ln=w[yn]⋅I(yn=i)⋅[−xn[yn]+logc=1∑Cexp(xn[c])]
其中 w = ( w 1 , ⋯ , w C ) , x n = ( x n 1 , ⋯ , x n C ) \boldsymbol{w}=(w_1,\cdots,w_C),\;\boldsymbol{x_n}=(x_{n1},\cdots,x_{nC}) w=(w1,⋯,wC),xn=(xn1,⋯,xnC)。再令 X = ( x 1 ; ⋯ ; x N ) , y = ( y 1 , ⋯ , y C ) {\bf X}=(\boldsymbol{x_1};\cdots;\boldsymbol{x_N}),\;\boldsymbol{y}=(y_1,\cdots,y_C) X=(x1;⋯;xN),y=(y1,⋯,yC),则显然 X {\bf X} X 就是我们的 input, y \boldsymbol{y} y 就是 target,于是我们可以进行批量计算
( l 1 , ⋯ , l N ) = w [ y ] ∗ I ( y ≠ i ) ∗ ( − X [ range ( len ( y ) ) , y ] + log ( sum ( exp ( X ) , dim = 1 ) ) ) (l_1,\cdots,l_N)=\boldsymbol{w}[\boldsymbol{y}] *\mathbb{I}(\boldsymbol{y}\neq i)* (-{\bf X}[\text{range}(\text{len}(\boldsymbol{y})),\,\boldsymbol{y}]+\log(\text{sum}(\exp({\bf X}),\,\text{dim}=1))) (l1,⋯,lN)=w[y]∗I(y=i)∗(−X[range(len(y)),y]+log(sum(exp(X),dim=1)))
其中 ∗ * ∗ 代表按元素相乘。上式采用了广播机制。
class CrossEntropyLoss(nn.Module):
def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super().__init__()
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, inputs, target):
if self.weight is not None:
n_samples_weight = self.weight[target] # 每个样本的权重
else:
n_samples_weight = torch.ones_like(target).float() # 不提供权重则默认全为1
indicator = (target != self.ignore_index).long().float() # long()方法可以将布尔型张量转化成0-1张量
raw_loss = -inputs[torch.arange(len(target)), target] + torch.log(torch.sum(torch.exp(inputs), dim=1))
result = n_samples_weight * indicator * raw_loss
if self.reduction == 'mean':
return torch.sum(result) / n_samples_weight.dot(indicator)
elif self.reduction == 'sum':
return torch.sum(result)
else:
return result
输出结果与 PyTorch 官方的 nn.CrossEntropyLoss 的完全相同,这里不再展示,读者可自行验证。
边栏推荐
- Basic number theory -- Euler function
- It cold knowledge (updating ing~)
- C#绘制带控制点的Bezier曲线,用于点阵图像及矢量图形
- 2310. 个位数字为 K 的整数之和
- OpenFeign
- Chris LATTNER, the father of llvm: why should we rebuild AI infrastructure software
- Programming implementation of ROS learning 5-client node
- 多元线性回归(梯度下降法)
- Use and programming method of ros-8 parameters
- Codeforces Round #648 (Div. 2) E.Maximum Subsequence Value
猜你喜欢

生成对抗网络

Confusing basic concepts member variables local variables global variables
![C [essential skills] use of configurationmanager class (use of file app.config)](/img/8b/e56f87c2d0fbbb1251ec01b99204a1.png)
C [essential skills] use of configurationmanager class (use of file app.config)

Codeworks round 639 (Div. 2) cute new problem solution
![[code practice] [stereo matching series] Classic ad census: (4) cross domain cost aggregation](/img/d8/7291a5b14160600ba73810e6dd1eb5.jpg)
[code practice] [stereo matching series] Classic ad census: (4) cross domain cost aggregation

Halcon clolor_ pieces. Hedv: classifier_ Color recognition

Hello everyone, welcome to my CSDN blog!

ROS learning 4 custom message

AUTOSAR从入门到精通100讲(103)-dbc文件的格式以及创建详解

Halcon affine transformations to regions
随机推荐
Halcon snap, get the area and position of coins
Programming implementation of subscriber node of ROS learning 3 subscriber
容易混淆的基本概念 成员变量 局部变量 全局变量
OpenFeign
Infix expression evaluation
Multiple linear regression (gradient descent method)
Use and programming method of ros-8 parameters
Programming implementation of ROS learning 5-client node
Codeworks round 639 (Div. 2) cute new problem solution
2020 "Lenovo Cup" National College programming online Invitational Competition and the third Shanghai University of technology programming competition
混淆矩阵(Confusion Matrix)
location search 属性获取登录用户名
Introduction Guide to stereo vision (4): DLT direct linear transformation of camera calibration [recommended collection]
Meta tag details
Introduction Guide to stereo vision (1): coordinate system and camera parameters
nodejs_ 01_ fs. readFile
Halcon blob analysis (ball.hdev)
asp.net(c#)的货币格式化
Bit operation related operations
Halcon clolor_ pieces. Hedv: classifier_ Color recognition