当前位置:网站首页>【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
2022-07-01 09:03:00 【Enzo 想砸电脑】
一、意会的理解 交叉熵作为损失函数的意义
交叉熵损失多用于 多分类函数,下面我们通过拆解交叉熵的公式来理解其作为损失函数的意义
假设我们在做一个 n分类的问题,模型预测的输出结果是 [ x 1 , x 2 , x 3 , . . . . , x n ] [x_1, x_2, x_3, ...., x_n] [x1,x2,x3,....,xn]
然后,我们需要定义一个损失函数,然后通过反向传播调整模型的权重,这里的 损失函数我们就选择 交叉熵损失函数啦~
nn.CrossEntropyLoss() 的公式为:
l o s s ( x , c l a s s ) = − l o g ( e x [ c l a s s ] ∑ j e x j ) = − x [ c l a s s ] + l o g ( ∑ j e x j ) loss(x, class) = -log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) = -x_{[class]} + log(\sum_j e^{x_{j}}) loss(x,class)=−log(∑jexjex[class])=−x[class]+log(j∑exj)
- x 是预测结果,是一个向量,其元素个数是需要由模型保证的,保证和分类数一样多
- class 表示这个样本的实际标签,比如,样本实际属于分类2,那么class=2
x [ c l a s s ] x_{[class]} x[class] 就是 x 2 x_2 x2,就是取测试结果向量中的第二个元素,也就是取其真实分类对应的那个预测值
上面铺垫完了,接下来,我们要拆解公式,理解公式了
1、首先,交叉熵损失函数公式中包含了一个最基础的部分: s o f t m a x ( x i ) = e x i ∑ j e x j softmax(x_i) = \frac{e^{x_i}}{\sum_je^{x_{j}}} softmax(xi)=∑jexjexi
softmax 将分类的结果做了归一化: e x e^x ex 先将数据映射到(0, 1] 的区间,再使所有分类概率相加的总和等于1。 经过softmax处理后,size不会变,每个值的意义是样本被分到这个分类的概率。
2、我们想要使预测结果中,真实分类的那个值的概率接近 100%。 我们取出真实分类的那个值:
e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class],我们希望它的值是 100%
3、作为损失函数的意义是:当预测结果越接近真实值,损失函数的值越接近于0。
我们把 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 取log,再取反,就能保证当 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 越接近于100%, l o s s = − l o g ( e x [ c l a s s ] ∑ j e x j ) loss=-log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) loss=−log(∑jexjex[class]) 越接近0。
二、应用:
假设有4张图片,或者说batch_ size=4。我们需要把这4张图片分类到5个类别上去,比如说:鸟,狗,猫,汽车,船
经过网络计算后,我们得到了预测结果:predict,size为[4, 5]
其真实标签为 label,size为 [4]
接下来使用 nn.CrossEntropyLoss() 计算 预测结果predict 和 真实值label 的交叉熵损失,可以
import torch
import torch.nn as nn
# -----------------------------------------
# 定义数据: batch_size=4; 一共有5个分类
# label.size() : torch.Size([4])
# predict.size(): torch.Size([4, 5])
# -----------------------------------------
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
print(predict)
print(label)
# -----------------------------------------
# 直接调用函数 nn.CrossEntropyLoss() 计算 Loss
# -----------------------------------------
criterion = nn.CrossEntropyLoss()
loss = criterion(predict, label)
print(loss)

nn.CrossEntropyLoss() 可以拆解成如下3个步骤,或者说可以由如下3个操作替换,其运算结果一毛一样:
- softmax:对每张图片的分类结果做softmax, softmax详细介绍
- log:对上面的结果 取log
(步骤1 和 步骤2 可以合并为 nn.logSoftmax() ) - NLL:nn.NLLLoss(a, b) 的操作是从a 中取出b对应的那个值(b中存的是 index值),再去掉负号(取反),然后求和取均值
import torch
import torch.nn as nn
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
softmax = nn.Softmax(dim=1)
nll = nn.NLLLoss()
temp1 = softmax(predict)
temp2 = torch.log(temp1)
output = nll(temp2, label)
print(output) # tensor(1.5230)
纯手撸版本
import torch
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
# softmax
temp1 = torch.exp(predict) / torch.sum(torch.exp(predict), dim=1, keepdim=True)
# log
temp2 = torch.log(temp1)
# nll
temp3 = torch.gather(temp2, dim=1, index=label.view(-1, 1))
temp4 = -temp3
output = torch.mean(temp4)
print(output) # tensor(1.5230)
边栏推荐
- Advanced C language pointer (Part 2)
- Nacos - Configuration Management
- Common interview questions for embedded engineers 2-mcu_ STM32
- Memory size end
- Win7 pyinstaller reports an error DLL load failed while importing after packaging exe_ Socket: parameter error
- Log4j 日志框架
- 个人装修笔记
- 毕业季,我想对你说
- 足球篮球体育比赛比分直播平台源码/app开发建设项目
- 3. Detailed explanation of Modbus communication protocol
猜你喜欢

Nacos - 配置管理

Principle and application of single chip microcomputer timer, serial communication and interrupt system

Nacos - 配置管理

Why is the Ltd independent station a Web3.0 website!

Pain points and solutions of equipment management in large factories

Only in China! Alicloud container service enters the Forrester leader quadrant

Jetson Nano 安装TensorFlow GPU及问题解决

Mysql 优化

Advanced level of C language pointer (Part 1)

Memory size end
随机推荐
ARM v7的体系结构A、R、M区别,分别应用在什么领域?
中考体育项目满分标准(深圳、安徽、湖北)
Bimianhongfu queren()
Set the type of the input tag to number, and remove the up and down arrows
【MFC开发(16)】树形控件Tree Control
Shell script -for loop and for int loop
【ESP 保姆级教程】疯狂毕设篇 —— 案例:基于阿里云、小程序、Arduino的温湿度监控系统
Jeecg restart alarm 40001
安装Oracle EE
Redis源码学习(29),压缩列表学习,ziplist.c(二)
动态代理
Shell script case in statement
Nacos - 配置管理
Glitch free clock switching technology
【ESP 保姆级教程】疯狂毕设篇 —— 案例:基于物联网的GY906红外测温门禁刷卡系统
Shell script - positional parameters (command line parameters)
FreeRTOS学习简易笔记
Shell script -read command: read data entered from the keyboard
Nacos - gestion de la configuration
Shell script echo command escape character