当前位置:网站首页>【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
2022-07-01 09:03:00 【Enzo 想砸电脑】
一、意会的理解 交叉熵作为损失函数的意义
交叉熵损失多用于 多分类函数,下面我们通过拆解交叉熵的公式来理解其作为损失函数的意义
假设我们在做一个 n分类的问题,模型预测的输出结果是 [ x 1 , x 2 , x 3 , . . . . , x n ] [x_1, x_2, x_3, ...., x_n] [x1,x2,x3,....,xn]
然后,我们需要定义一个损失函数,然后通过反向传播调整模型的权重,这里的 损失函数我们就选择 交叉熵损失函数啦~
nn.CrossEntropyLoss() 的公式为:
l o s s ( x , c l a s s ) = − l o g ( e x [ c l a s s ] ∑ j e x j ) = − x [ c l a s s ] + l o g ( ∑ j e x j ) loss(x, class) = -log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) = -x_{[class]} + log(\sum_j e^{x_{j}}) loss(x,class)=−log(∑jexjex[class])=−x[class]+log(j∑exj)
- x 是预测结果,是一个向量,其元素个数是需要由模型保证的,保证和分类数一样多
- class 表示这个样本的实际标签,比如,样本实际属于分类2,那么class=2
x [ c l a s s ] x_{[class]} x[class] 就是 x 2 x_2 x2,就是取测试结果向量中的第二个元素,也就是取其真实分类对应的那个预测值
上面铺垫完了,接下来,我们要拆解公式,理解公式了
1、首先,交叉熵损失函数公式中包含了一个最基础的部分: s o f t m a x ( x i ) = e x i ∑ j e x j softmax(x_i) = \frac{e^{x_i}}{\sum_je^{x_{j}}} softmax(xi)=∑jexjexi
softmax 将分类的结果做了归一化: e x e^x ex 先将数据映射到(0, 1] 的区间,再使所有分类概率相加的总和等于1。 经过softmax处理后,size不会变,每个值的意义是样本被分到这个分类的概率。
2、我们想要使预测结果中,真实分类的那个值的概率接近 100%。 我们取出真实分类的那个值:
e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class],我们希望它的值是 100%
3、作为损失函数的意义是:当预测结果越接近真实值,损失函数的值越接近于0。
我们把 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 取log,再取反,就能保证当 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 越接近于100%, l o s s = − l o g ( e x [ c l a s s ] ∑ j e x j ) loss=-log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) loss=−log(∑jexjex[class]) 越接近0。
二、应用:
假设有4张图片,或者说batch_ size=4。我们需要把这4张图片分类到5个类别上去,比如说:鸟,狗,猫,汽车,船
经过网络计算后,我们得到了预测结果:predict,size为[4, 5]
其真实标签为 label,size为 [4]
接下来使用 nn.CrossEntropyLoss() 计算 预测结果predict 和 真实值label 的交叉熵损失,可以
import torch
import torch.nn as nn
# -----------------------------------------
# 定义数据: batch_size=4; 一共有5个分类
# label.size() : torch.Size([4])
# predict.size(): torch.Size([4, 5])
# -----------------------------------------
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
print(predict)
print(label)
# -----------------------------------------
# 直接调用函数 nn.CrossEntropyLoss() 计算 Loss
# -----------------------------------------
criterion = nn.CrossEntropyLoss()
loss = criterion(predict, label)
print(loss)

nn.CrossEntropyLoss() 可以拆解成如下3个步骤,或者说可以由如下3个操作替换,其运算结果一毛一样:
- softmax:对每张图片的分类结果做softmax, softmax详细介绍
- log:对上面的结果 取log
(步骤1 和 步骤2 可以合并为 nn.logSoftmax() ) - NLL:nn.NLLLoss(a, b) 的操作是从a 中取出b对应的那个值(b中存的是 index值),再去掉负号(取反),然后求和取均值
import torch
import torch.nn as nn
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
softmax = nn.Softmax(dim=1)
nll = nn.NLLLoss()
temp1 = softmax(predict)
temp2 = torch.log(temp1)
output = nll(temp2, label)
print(output) # tensor(1.5230)
纯手撸版本
import torch
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
# softmax
temp1 = torch.exp(predict) / torch.sum(torch.exp(predict), dim=1, keepdim=True)
# log
temp2 = torch.log(temp1)
# nll
temp3 = torch.gather(temp2, dim=1, index=label.view(-1, 1))
temp4 = -temp3
output = torch.mean(temp4)
print(output) # tensor(1.5230)
边栏推荐
- Phishing identification app
- 软件工程师面试刷题网站、经验方法
- FAQ | FAQ for building applications for large screen devices
- In the middle of the year, where should fixed asset management go?
- 【MFC开发(16)】树形控件Tree Control
- Microcomputer principle - bus and its formation
- 中考体育项目满分标准(深圳、安徽、湖北)
- Shell脚本-字符串
- Summary of reptile knowledge points
- Foundation: 2 The essence of image
猜你喜欢

Insert mathematical formula in MD document and mathematical formula in typora

How can enterprises and developers take the lead in the outbreak of cloud native landing?

Glitch free clock switching technology

安装Oracle EE

Computer tips

Jetson Nano 安装TensorFlow GPU及问题解决

嵌入式工程师面试题3-硬件

TV size and viewing distance

Vsync+ triple cache mechanism +choreographer

Only in China! Alicloud container service enters the Forrester leader quadrant
随机推荐
C语言学生信息管理系统
Dynamic proxy
嵌入式工程师面试题3-硬件
Leetcode daily question brushing record --540 A single element in an ordered array
Key points of NFT supervision and overseas policies
ARM v7的体系结构A、R、M区别,分别应用在什么领域?
Principles of Microcomputer - Introduction
Flink面试题
嵌入式工程师面试-常问问题集
Vsync+ triple cache mechanism +choreographer
类加载
NiO zero copy
Shell script -read command: read data entered from the keyboard
小鸟识别APP
Nacos - 配置管理
pcl_viewer命令
如何一站式高效管理固定资产?
Principle and application of single chip microcomputer timer, serial communication and interrupt system
AVL树的理解和实现
1. Connection between Jetson and camera