当前位置:网站首页>深入浅出PyTorch中的nn.CrossEntropyLoss
深入浅出PyTorch中的nn.CrossEntropyLoss
2022-07-05 09:01:00 【aelum】
作者简介:非科班转码,正在不断丰富自己的技术栈
️ 博客主页:https://raelum.blog.csdn.net
主要领域:NLP、RS、GNN
如果这篇文章有帮助到你,可以关注️ + 点赞 + 收藏 + 留言,这将是我创作的最大动力
一、前言
nn.CrossEntropyLoss
常用作多分类问题的损失函数(对交叉熵还不了解的读者可以看我的这篇文章),本文将围绕PyTorch的官方文档对重要知识点进行逐一讲解(不会全部讲解)。
import torch
import torch.nn as nn
二、理论基础
对于 C ( C > 2 ) C\,(C>2) C(C>2) 分类问题,先不考虑 batch 的情形,设神经网络的输出(还未经过 Softmax)为 { x c } c = 1 C \{x_c\}_{c=1}^C { xc}c=1C,经过 Softmax 后得到
q i = exp ( x i ) ∑ c = 1 C exp ( x c ) q_i=\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} qi=∑c=1Cexp(xc)exp(xi)
从而该样本的交叉熵损失为
H ( p , q ) = − ∑ i = 1 C p i log q i = − ∑ i = 1 C p i log exp ( x i ) ∑ c = 1 C exp ( x c ) H(p,q)=-\sum_{i=1}^C p_i\log q_i=-\sum_{i=1}^C p_i\log\frac{\exp(x_i)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=−i=1∑Cpilogqi=−i=1∑Cpilog∑c=1Cexp(xc)exp(xi)
其中 ( p 1 , p 2 , ⋯ , p C ) (p_1,p_2,\cdots,p_C) (p1,p2,⋯,pC) 是 One-Hot 向量。
不妨令 p y = 1 ( y ∈ { 1 , 2 , ⋯ , C } ) p_y=1\,(y\in\{1,2,\cdots,C\}) py=1(y∈{ 1,2,⋯,C}),其余为 0 0 0,因此上式变为
H ( p , q ) = − log exp ( x y ) ∑ c = 1 C exp ( x c ) H(p,q)=-\log\frac{\exp(x_y)}{\sum_{c=1}^C\exp(x_c)} H(p,q)=−log∑c=1Cexp(xc)exp(xy)
现在考虑有 batch 的情形,不妨设 batch size 为 N N N,神经网络的输出为 { x n c } n c , n = 1 , ⋯ , N , c = 1 , ⋯ , C \{x_{nc}\}_{nc},\;n=1,\cdots,N,\;c=1,\cdots,C { xnc}nc,n=1,⋯,N,c=1,⋯,C,第 n n n 个样本的真实类别记为 y n ( y n ∈ { 1 , 2 , ⋯ , C } ) y_n\,(y_n\in\{1,2,\cdots,C\}) yn(yn∈{ 1,2,⋯,C}),第 n n n 个样本的交叉熵损失记为 l n l_n ln,则仿照上式就有
l n = − log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=−log∑c=1Cexp(xnc)exp(xn,yn)
接下来我们讨论一些特殊情形。当数据不平衡时(某一类的样本数特别多,另一类的样本数特别少),我们需要为每一类的损失安排一个权重用来平衡。权重为 w = ( w 1 , w 2 , ⋯ , w C ) \boldsymbol{w}=(w_1,w_2,\cdots,w_C) w=(w1,w2,⋯,wC)。
模型容易在样本数最多的一个(或几个)类上过拟合,因此对于那些样本数较少的类,我们需要设置更高的权重,这样模型在预测这些类的标签时一旦出错,就会受到更多的惩罚
安排了权重后,相应的损失为
l n = − w y n log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) l_n=-w_{y_n}\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})} ln=−wynlog∑c=1Cexp(xnc)exp(xn,yn)
计算完 l 1 , l 2 , ⋯ , l N l_1,l_2,\cdots,l_N l1,l2,⋯,lN 后,我们既可以一次性将它们全部返回(对应 reduction=none
),也可以返回它们的均值(对应 reduction=mean
),还可以返回它们的和(对应 reduction=sum
):
ℓ = { ( l 1 , ⋯ , l N ) , reduction=none ∑ n = 1 N l n / ∑ n = 1 N w y n , reduction=mean ∑ n = 1 N l n , reduction=sum \ell=\begin{cases} (l_1,\cdots,l_N),&\text{reduction=none} \\ \sum_{n=1}^N l_n/\sum_{n=1}^N w_{y_n},&\text{reduction=mean} \\ \sum_{n=1}^N l_n,&\text{reduction=sum} \\ \end{cases} ℓ=⎩⎪⎨⎪⎧(l1,⋯,lN),∑n=1Nln/∑n=1Nwyn,∑n=1Nln,reduction=nonereduction=meanreduction=sum
在 NLP 任务中,我们往往将填充词元添加到每个序列的末尾,这样一来不同长度的序列可以进行批量加载。训练过程中,我们不希望网络预测出的填充词元被算入损失函数中。不妨设填充词元在词表中的索引为 i i i,则此时应对 l n l_n ln 作如下修正:
l n = − w y n ⋅ I ( y n ≠ i ) ⋅ log exp ( x n , y n ) ∑ c = 1 C exp ( x n c ) , where I ( x ) = { 1 , x is True 0 , x is False l_n=-w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot\log \frac{\exp(x_{n,y_n}{})}{\sum_{c=1}^C\exp(x_{nc})},\qquad \text{where}\; \mathbb{I}(x)= \begin{cases} 1,&x\; \text{is True} \\ 0,&x\; \text{is False} \end{cases} ln=−wyn⋅I(yn=i)⋅log∑c=1Cexp(xnc)exp(xn,yn),whereI(x)={ 1,0,xis Truexis False
另外,该场景下的 reduction=mean
对应的损失变为
ℓ = ∑ n = 1 N l n ∑ n = 1 N w y n ⋅ I ( y n ≠ i ) \ell=\sum_{n=1}^N\frac{l_n}{\sum_{n=1}^Nw_{y_n}\cdot \mathbb{I}(y_n\neq i)} ℓ=n=1∑N∑n=1Nwyn⋅I(yn=i)ln
需要注意的是,在PyTorch中 y n ∈ { 0 , 1 , ⋯ , C − 1 } y_n\in\{0,1,\cdots,C-1\} yn∈{ 0,1,⋯,C−1},这里我们之所以用 { 1 , 2 , ⋯ , C } \{1,2,\cdots,C\} { 1,2,⋯,C} 是为了更自然地衔接上下文
三、主要参数
nn.CrossEntropyLoss
的主要参数如下:
nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0)
️
size_average
和reduce
参数已经弃用,取而代之的是reduction
参数,所以这里不再讲解
有了前面的铺垫,我们就可以很容易理解这些参数了:
weight
:长度为 C C C 的张量,一般在数据不平衡时才会使用;ignore_index
:需要忽略的类别的索引,默认为 − 100 -100 −100,即不忽略;reduction
:决定以何种形式返回损失。为none
时返回 N N N 个样本的损失,为mean
时返回 N N N 个样本的损失均值,为sum
时返回 N N N 个样本的损失的和。默认为mean
;label_smoothing
:决定是否开启标签平滑(不了解标签平滑的读者可参考这篇文章),数值在 [ 0 , 1 ] [0,1] [0,1] 内。默认为 0 0 0,即不开启。
3.1 输入与输出
输入分为 input
和 target
,input
通常为 ( N , C ) (N,C) (N,C) 的形状(即 batch_size × num_classes
),target
通常为 ( N , ) (N,) (N,) 的形状,其中的每个分量均位于 [ 0 , C − 1 ] ∩ Z [0,C-1] \cap \mathbb{Z} [0,C−1]∩Z 中,代表样本属于的类别。
input
和target
还可以是其他类型的输入,但本文只讨论这种使用最为广泛的输入input
是神经网络的原始输出(未经过 Softmax),nn.CrossEntropyLoss
会自动对其应用 Softmax
torch.manual_seed(0)
batch_size = 3
num_classes = 5
criterion_1 = nn.CrossEntropyLoss(reduction='none')
criterion_2 = nn.CrossEntropyLoss()
criterion_3 = nn.CrossEntropyLoss(reduction='sum')
inputs = torch.randn(batch_size, num_classes) # 避免与input关键字冲突(当然这无所谓)
target = torch.randint(num_classes, size=(batch_size, ))
print(criterion_1(inputs, target)) # 输出3个样本的loss
# tensor([1.4639, 3.0493, 2.3056])
print(criterion_2(inputs, target)) # 输出3个样本的loss的均值
# tensor(2.2729)
print(criterion_3(inputs, target)) # 输出3个样本的loss的和
# tensor(6.8188)
print(sum(criterion_1(inputs, target)) == criterion_3(inputs, target))
# tensor(True)
print(sum(criterion_1(inputs, target)) / batch_size == criterion_2(inputs, target))
# tensor(True)
四、从零开始实现 nn.CrossEntropyLoss
为了加深理解,接下来我们从零开始实现 nn.CrossEntropyLoss
(当然会和官方不同,为了追求可读性会采用傻瓜式实现)。
首先确定框架(为简便起见这里不考虑 label_smoothing
):
class CrossEntropyLoss(nn.Module):
def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super().__init__()
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, inputs, target):
pass
为方便计算,我们对第二章节的损失计算公式进行改写
l n = w y n ⋅ I ( y n ≠ i ) ⋅ [ − x n , y n + log ∑ c = 1 C exp ( x n c ) ] l_n=w_{y_n}\cdot \mathbb{I}(y_n\neq i)\cdot[-x_{n,y_n}+\log\sum_{c=1}^C\exp(x_{nc})] ln=wyn⋅I(yn=i)⋅[−xn,yn+logc=1∑Cexp(xnc)]
采用更符合 Python 的表述方式来改写上式
l n = w [ y n ] ⋅ I ( y n ≠ i ) ⋅ [ − x n [ y n ] + log ∑ c = 1 C exp ( x n [ c ] ) ] l_n=\boldsymbol{w}[y_n]\cdot \mathbb{I}(y_n\neq i)\cdot[-\boldsymbol{x_n}[y_n]+\log\sum_{c=1}^C\exp(\boldsymbol{x_n}[c])] ln=w[yn]⋅I(yn=i)⋅[−xn[yn]+logc=1∑Cexp(xn[c])]
其中 w = ( w 1 , ⋯ , w C ) , x n = ( x n 1 , ⋯ , x n C ) \boldsymbol{w}=(w_1,\cdots,w_C),\;\boldsymbol{x_n}=(x_{n1},\cdots,x_{nC}) w=(w1,⋯,wC),xn=(xn1,⋯,xnC)。再令 X = ( x 1 ; ⋯ ; x N ) , y = ( y 1 , ⋯ , y C ) {\bf X}=(\boldsymbol{x_1};\cdots;\boldsymbol{x_N}),\;\boldsymbol{y}=(y_1,\cdots,y_C) X=(x1;⋯;xN),y=(y1,⋯,yC),则显然 X {\bf X} X 就是我们的 input
, y \boldsymbol{y} y 就是 target
,于是我们可以进行批量计算
( l 1 , ⋯ , l N ) = w [ y ] ∗ I ( y ≠ i ) ∗ ( − X [ range ( len ( y ) ) , y ] + log ( sum ( exp ( X ) , dim = 1 ) ) ) (l_1,\cdots,l_N)=\boldsymbol{w}[\boldsymbol{y}] *\mathbb{I}(\boldsymbol{y}\neq i)* (-{\bf X}[\text{range}(\text{len}(\boldsymbol{y})),\,\boldsymbol{y}]+\log(\text{sum}(\exp({\bf X}),\,\text{dim}=1))) (l1,⋯,lN)=w[y]∗I(y=i)∗(−X[range(len(y)),y]+log(sum(exp(X),dim=1)))
其中 ∗ * ∗ 代表按元素相乘。上式采用了广播机制。
class CrossEntropyLoss(nn.Module):
def __init__(self, weight=None, ignore_index=-100, reduction='mean'):
super().__init__()
self.weight = weight
self.ignore_index = ignore_index
self.reduction = reduction
def forward(self, inputs, target):
if self.weight is not None:
n_samples_weight = self.weight[target] # 每个样本的权重
else:
n_samples_weight = torch.ones_like(target).float() # 不提供权重则默认全为1
indicator = (target != self.ignore_index).long().float() # long()方法可以将布尔型张量转化成0-1张量
raw_loss = -inputs[torch.arange(len(target)), target] + torch.log(torch.sum(torch.exp(inputs), dim=1))
result = n_samples_weight * indicator * raw_loss
if self.reduction == 'mean':
return torch.sum(result) / n_samples_weight.dot(indicator)
elif self.reduction == 'sum':
return torch.sum(result)
else:
return result
输出结果与 PyTorch 官方的 nn.CrossEntropyLoss
的完全相同,这里不再展示,读者可自行验证。
边栏推荐
- AdaBoost use
- Codeworks round 681 (Div. 2) supplement
- Programming implementation of ROS learning 5-client node
- Halcon wood texture recognition
- Install the CPU version of tensorflow+cuda+cudnn (ultra detailed)
- Halcon clolor_ pieces. Hedv: classifier_ Color recognition
- Introduction Guide to stereo vision (4): DLT direct linear transformation of camera calibration [recommended collection]
- golang 基础 —— golang 向 mysql 插入的时间数据和本地时间不一致
- 一题多解,ASP.NET Core应用启动初始化的N种方案[上篇]
- Halcon blob analysis (ball.hdev)
猜你喜欢
Programming implementation of ROS learning 5-client node
Introduction Guide to stereo vision (5): dual camera calibration [no more collection, I charge ~]
Rebuild my 3D world [open source] [serialization-2]
牛顿迭代法(解非线性方程)
Rebuild my 3D world [open source] [serialization-3] [comparison between colmap and openmvg]
Applet (subcontracting)
IT冷知识(更新ing~)
Redis implements a high-performance full-text search engine -- redisearch
Count of C # LINQ source code analysis
Rebuild my 3D world [open source] [serialization-1]
随机推荐
Hello everyone, welcome to my CSDN blog!
Introduction Guide to stereo vision (4): DLT direct linear transformation of camera calibration [recommended collection]
嗨 FUN 一夏,与 StarRocks 一起玩转 SQL Planner!
Blue Bridge Cup provincial match simulation question 9 (MST)
Illustrated network: what is gateway load balancing protocol GLBP?
Return of missing persons
C#图像差异对比:图像相减(指针法、高速)
Meta tag details
AdaBoost use
Add discount recharge and discount shadow ticket plug-ins to the resource realization applet
Causes and appropriate analysis of possible errors in seq2seq code of "hands on learning in depth"
[code practice] [stereo matching series] Classic ad census: (4) cross domain cost aggregation
图解网络:什么是网关负载均衡协议GLBP?
Solutions of ordinary differential equations (2) examples
驾驶证体检医院(114---2 挂对应的医院司机体检)
Introduction Guide to stereo vision (2): key matrix (essential matrix, basic matrix, homography matrix)
Codeforces round 684 (Div. 2) e - green shopping (line segment tree)
Programming implementation of ROS learning 6 -service node
asp.net(c#)的货币格式化
Redis实现高性能的全文搜索引擎---RediSearch