当前位置:网站首页>深入浅出PyTorch中的nn.CrossEntropyLoss
深入浅出PyTorch中的nn.CrossEntropyLoss
2022-07-05 09:01:00 【aelum】
作者简介:非科班转码,正在不断丰富自己的技术栈
️ 博客主页:https://raelum.blog.csdn.net
主要领域:NLP、RS、GNN
如果这篇文章有帮助到你,可以关注️ + 点赞 + 收藏 + 留言,这将是我创作的最大动力
一、前言
nn.CrossEntropyLoss
常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。
import torch
import torch.nn as nn
二、理论基础
对于 C ( C > 2 ) C\,(C>2) C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(还未经过 Softmax)为 { x c } c = 1 C \{x_c\}_{c=1}^C { xc}c=1C,经过 Softmax 后得到
q i = exp ( x i ) ∑ c = 1 C exp ( x c ) q_i=\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} qi=∑c=1Cexp(xc)exp(xi)
从而该样本的交叉熵损失为
H ( p , q ) = − ∑ i = 1 C p i log q i = − ∑ i = 1 C p i log exp ( x i ) ∑ c = 1 C exp ( x c ) H(p,q)=-\sum_{i=1}^C p_i\log q_i=-\sum_{i=1}^C p_i\log\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=−i=1∑Cpilogqi=−i=1∑Cpilog∑c=1Cexp(xc)exp(xi)
其中 ( p 1 , p 2 , ⋯ , p C ) (p_1,p_2,\cdots,p_C) (p1,p2,⋯,pC) 是 One-Hot 向量。
不妨令 p y = 1 ( y ∈ { 1 , 2 , ⋯ , C } ) p_y=1\,(y\in\{1,2,\cdots,C\}) py=1(y∈{ 1,2,⋯,C}),其余为 0 0 0,因此上式变为
H ( p , q ) = − log exp ( x y ) ∑ c = 1 C exp ( x c ) H(p,q)=-\log\frac{\exp(x_y)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=−log∑c=1Cexp(xc)exp(xy)
现在考虑有 batch 的情形,不妨设 batch size 为 N N N,神经网络的输出为 { x n c } n c , n = 1 , ⋯ , N , c = 1 , ⋯ , C \{x_{nc}\}_{nc},\;n=1,\cdots,N,\;c=1,\cdots,C { xnc}nc,n=1,⋯,N,c=1,⋯,C,第 n n n 个样本的真实类别记为 y n ( y n ∈ { 1 , 2 , ⋯ , C } ) y_n\,(y_n\in\{1,2,\cdots,C\}) yn(yn∈{ 1,2,⋯,C}),第 n n n 个样本的交叉熵损失记为 l n l_n ln,则仿照上式就有
l n = − log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=−log∑c=1Cexp(xnc)exp(xn,yn)
接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为 w = ( w 1 , w 2 , ⋯ , w C ) \boldsymbol{w}=(w_1,w_2,\cdots,w_C) w=(w1,w2,⋯,wC)。
模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚
安排了权重后,相应的损失为
l n = − w y n log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-w_{y_n}\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=−wynlog∑c=1Cexp(xnc)exp(xn,yn)
计算完 l 1 , l 2 , ⋯ , l N l_1,l_2,\cdots,l_N l1,l2,⋯,lN 后,我们既可以一次性将它们全部返回(对应 reduction=none
),也可以返回它们的均值(对应 reduction=mean
),还可以返回它们的和(对应 reduction=sum
):
ℓ = { ( l 1 , ⋯ , l N ) , reduction=none ∑ n = 1 N l n / ∑ n = 1 N w y n , reduction=mean ∑ n = 1 N l n , reduction=sum \ell=\begin{cases} (l_1,\cdots,l_N),&\text{reduction=none} \\ \sum_{n=1}^N l_n/\sum_{n=1}^N w_{y_n},&\text{reduction=mean} \\ \sum_{n=1}^N l_n,&\text{reduction=sum} \\ \end{cases} ℓ=⎩⎪⎨⎪⎧(l1,⋯,lN),∑n=1Nln/∑n=1Nwyn,∑n=1Nln,reduction=nonereduction=meanreduction=sum
在 NLP 任务中,我们往往将填充词元添加到每个序列的末尾,这样一来不同长度的序列可以进行批量加载。训练过程中,我们不希望网络预测出的填充词元被算入损失函数中。不妨设填充词元在词表中的索引为 i i i,则此时应对 l n l_n ln 作如下修正:
l n = − w y n ⋅ I ( y n ≠ i ) ⋅ log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) , where I ( x ) = { 1 , x is True 0 , x is False l_n=-w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})},\qquad \text{where}\; \mathbb{I}(x)= \begin{cases} 1,&x\; \text{is True} \\ 0,&x\; \text{is False} \end{cases} ln=−wyn⋅I(yn=i)⋅log∑c=1Cexp(xnc)exp(xn,yn),whereI(x)={ 1,0,xis Truexis False
另外,该场景下的 reduction=mean
对应的损失变为
ℓ = ∑ n = 1 N l n ∑ n = 1 N w y n ⋅ I ( y n ≠ i ) \ell=\sum_{n=1}^N\frac{l_n}{\sum_{n=1}^Nw_{y_n}\cdot \mathbb{I}(y_n\neq i)} ℓ=n=1∑N∑n=1Nwyn⋅I(yn=i)ln
需要注意的是,在PyTorch中 y n ∈ { 0 , 1 , ⋯ , C − 1 } y_n\in\{0,1,\cdots,C-1\} yn∈{ 0,1,⋯,C−1},这里我们之所以用 { 1 , 2 , ⋯ , C } \{1,2,\cdots,C\} { 1,2,⋯,C} 是为了更自然地衔接上下文
三、主要参数
nn.CrossEntropyLoss
的主要参数如下:
nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
️
size_average
和reduce
参数已经弃用,取而代之的是reduction
参数,所以这里不再讲解
有了前面的铺垫,我们就可以很容易理解这些参数了:
weight
:长度为 C C C 的张量,一般在数据不平衡时才会使用;ignore_index
:需要忽略的类别的索引,默认为 − 100 -100 −100,即不忽略;reduction
:决定以何种形式返回损失。为none
时返回 N N N 个样本的损失,为mean
时返回 N N N 个样本的损失均值,为sum
时返回 N N N 个样本的损失的和。默认为mean
;label_smoothing
:决定是否开启标签平滑(不了解标签平滑的读者可参考这篇文章),数值在 [ 0 , 1 ] [0,1] [0,1] 内。默认为 0 0 0,即不开启。
3.1 输入与输出
输入分为 input
和 target
,input
通常为 ( N , C ) (N,C) (N,C) 的形状(即 batch_size × num_classes
),target
通常为 ( N , ) (N,) (N,) 的形状,其中的每个分量均位于 [ 0 , C − 1 ] ∩ Z [0,C-1] \cap \mathbb{Z} [0,C−1]∩Z 中,代表样本属于的类别。
input
和target
还可以是其他类型的输入,但本文只讨论这种使用最为广泛的输入input
是神经网络的原始输出(未经过 Softmax),nn.CrossEntropyLoss
会自动对其应用 Softmax
torch.manual_seed(0)
batch_size = 3
num_classes = 5
criterion_1 = nn.CrossEntropyLoss(reduction='none')
criterion_2 = nn.CrossEntropyLoss()
criterion_3 = nn.CrossEntropyLoss(reduction='sum')
inputs = torch.randn(batch_size, num_classes) # 避免与input关键字冲突(当然这无所谓)
target = torch.randint(num_classes, size=(batch_size, ))
print(criterion_1(inputs, target)) # 输出3个样本的loss
# tensor([1.4639, 3.0493, 2.3056])
print(criterion_2(inputs, target)) # 输出3个样本的loss的均值
# tensor(2.2729)
print(criterion_3(inputs, target)) # 输出3个样本的loss的和
# tensor(6.8188)
print(sum(criterion_1(inputs, target)) == criterion_3(inputs, target))
# tensor(True)
print(sum(criterion_1(inputs, target)) / batch_size == criterion_2(inputs, target))
# tensor(True)
四、从零开始实现 nn.CrossEntropyLoss
为了加深理解,接下来我们从零开始实现 nn.CrossEntropyLoss
(当然会和官方不同,为了追求可读性会采用傻瓜式实现)。
首先确定框架(为简便起见这里不考虑 label_smoothing
):
class CrossEntropyLoss(nn.Module):
def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super().__init__()
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, inputs, target):
pass
为方便计算,我们对第二章节的损失计算公式进行改写
l n = w y n ⋅ I ( y n ≠ i ) ⋅ [ − x n , y n + log ∑ c = 1 C exp ( x n c ) ] l_n=w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot[-x_{n,y_n}+\log\sum_{c=1}^C\exp(x_{nc})] ln=wyn⋅I(yn=i)⋅[−xn,yn+logc=1∑Cexp(xnc)]
采用更符合 Python 的表述方式来改写上式
l n = w [ y n ] ⋅ I ( y n ≠ i ) ⋅ [ − x n [ y n ] + log ∑ c = 1 C exp ( x n [ c ] ) ] l_n=\boldsymbol{w}[y_n]\cdot \mathbb{I}(y_n\neq i)\cdot[-\boldsymbol{x_n}[y_n]+\log\sum_{c=1}^C\exp(\boldsymbol{x_n}[c])] ln=w[yn]⋅I(yn=i)⋅[−xn[yn]+logc=1∑Cexp(xn[c])]
其中 w = ( w 1 , ⋯ , w C ) , x n = ( x n 1 , ⋯ , x n C ) \boldsymbol{w}=(w_1,\cdots,w_C),\;\boldsymbol{x_n}=(x_{n1},\cdots,x_{nC}) w=(w1,⋯,wC),xn=(xn1,⋯,xnC)。再令 X = ( x 1 ; ⋯ ; x N ) , y = ( y 1 , ⋯ , y C ) {\bf X}=(\boldsymbol{x_1};\cdots;\boldsymbol{x_N}),\;\boldsymbol{y}=(y_1,\cdots,y_C) X=(x1;⋯;xN),y=(y1,⋯,yC),则显然 X {\bf X} X 就是我们的 input
, y \boldsymbol{y} y 就是 target
,于是我们可以进行批量计算
( l 1 , ⋯ , l N ) = w [ y ] ∗ I ( y ≠ i ) ∗ ( − X [ range ( len ( y ) ) , y ] + log ( sum ( exp ( X ) , dim = 1 ) ) ) (l_1,\cdots,l_N)=\boldsymbol{w}[\boldsymbol{y}] *\mathbb{I}(\boldsymbol{y}\neq i)* (-{\bf X}[\text{range}(\text{len}(\boldsymbol{y})),\,\boldsymbol{y}]+\log(\text{sum}(\exp({\bf X}),\,\text{dim}=1))) (l1,⋯,lN)=w[y]∗I(y=i)∗(−X[range(len(y)),y]+log(sum(exp(X),dim=1)))
其中 ∗ * ∗ 代表按元素相乘。上式采用了广播机制。
class CrossEntropyLoss(nn.Module):
def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super().__init__()
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, inputs, target):
if self.weight is not None:
n_samples_weight = self.weight[target] # 每个样本的权重
else:
n_samples_weight = torch.ones_like(target).float() # 不提供权重则默认全为1
indicator = (target != self.ignore_index).long().float() # long()方法可以将布尔型张量转化成0-1张量
raw_loss = -inputs[torch.arange(len(target)), target] + torch.log(torch.sum(torch.exp(inputs), dim=1))
result = n_samples_weight * indicator * raw_loss
if self.reduction == 'mean':
return torch.sum(result) / n_samples_weight.dot(indicator)
elif self.reduction == 'sum':
return torch.sum(result)
else:
return result
输出结果与 PyTorch 官方的 nn.CrossEntropyLoss
的完全相同,这里不再展示,读者可自行验证。
边栏推荐
- Nodejs modularization
- Halcon blob analysis (ball.hdev)
- MPSoC QSPI Flash 升级办法
- Count of C # LINQ source code analysis
- 2309. 兼具大小写的最好英文字母
- Introduction Guide to stereo vision (1): coordinate system and camera parameters
- 某公司文件服务器迁移方案
- golang 基础 —— golang 向 mysql 插入的时间数据和本地时间不一致
- nodejs_ fs. writeFile
- Typescript hands-on tutorial, easy to understand
猜你喜欢
Introduction Guide to stereo vision (2): key matrix (essential matrix, basic matrix, homography matrix)
Halcon Chinese character recognition
My experience from technology to product manager
[beauty of algebra] singular value decomposition (SVD) and its application to linear least squares solution ax=b
Hello everyone, welcome to my CSDN blog!
Introduction Guide to stereo vision (7): stereo matching
Ros-11 common visualization tools
嗨 FUN 一夏,与 StarRocks 一起玩转 SQL Planner!
Add discount recharge and discount shadow ticket plug-ins to the resource realization applet
[code practice] [stereo matching series] Classic ad census: (5) scan line optimization
随机推荐
Mengxin summary of LIS (longest ascending subsequence) topics
MPSoC QSPI flash upgrade method
Codeworks round 638 (Div. 2) cute new problem solution
fs. Path module
ROS learning 4 custom message
golang 基础 —— golang 向 mysql 插入的时间数据和本地时间不一致
C#绘制带控制点的Bezier曲线,用于点阵图像及矢量图形
[Niuke brush questions day4] jz55 depth of binary tree
2020 "Lenovo Cup" National College programming online Invitational Competition and the third Shanghai University of technology programming competition
Nodemon installation and use
Driver's license physical examination hospital (114-2 hang up the corresponding hospital driver physical examination)
How many checks does kubedm series-01-preflight have
Introduction Guide to stereo vision (5): dual camera calibration [no more collection, I charge ~]
nodejs_ fs. writeFile
Huber Loss
[beauty of algebra] solution method of linear equations ax=0
ECMAScript6介绍及环境搭建
我从技术到产品经理的几点体会
Blogger article navigation (classified, real-time update, permanent top)
2311. 小于等于 K 的最长二进制子序列