[PyTorch] nn.CrossEntropyLoss() vs. nn.NLLLoss()
2022-07-01 09:03:00 【Enzo 想砸电脑】
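I. An intuitive understanding of cross-entropy as a loss function
Cross-entropy loss is mostly used for multi-class classification. Below, we take its formula apart to see why it makes sense as a loss function.
Suppose we are solving an n-class classification problem and the model's predicted output is $[x_1, x_2, x_3, \dots, x_n]$.
Next, we need a loss function so that backpropagation can adjust the model's weights; here we choose the cross-entropy loss.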
The formula of nn.CrossEntropyLoss() is:
$$\text{loss}(x, \text{class}) = -\log\left(\frac{e^{x_{[\text{class}]}}}{\sum_j e^{x_j}}\right) = -x_{[\text{class}]} + \log\left(\sum_j e^{x_j}\right)$$
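- x is the prediction, a vector; the model must guarantee that it has exactly as many elements as there are classes
- class is the true label of the sample; for example, if the sample actually belongs to class 2, then class = 2
$x_{[\text{class}]}$ is then $x_2$: the element of the prediction vector at position 2, i.e. the predicted value that corresponds to the true class. (The formula indexes from 1; in PyTorch, class labels are 0-based indices.)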
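A quick numerical check, in plain Python with the standard math module (so it does not need PyTorch), that the two forms of the formula above agree. The helper names and the score vector x are hypothetical, made up for illustration; cls is a 0-based index, the convention PyTorch uses for class labels.

```python
import math

# Hypothetical helper: cross-entropy written as -log(softmax(x)[cls])
def ce_via_softmax(x, cls):
    denom = sum(math.exp(v) for v in x)
    return -math.log(math.exp(x[cls]) / denom)

# Hypothetical helper: the same loss written as -x[cls] + log(sum_j exp(x_j))
def ce_via_logsumexp(x, cls):
    return -x[cls] + math.log(sum(math.exp(v) for v in x))

x = [0.2, 1.5, -0.3, 0.8]  # made-up raw scores for a 4-class problem
cls = 1                    # 0-based index of the true class

print(ce_via_softmax(x, cls), ce_via_logsumexp(x, cls))  # equal up to rounding
```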
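With the groundwork in place, let's take the formula apart and interpret it.
Step 1: the cross-entropy formula contains a basic building block, softmax: $\text{softmax}(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$
softmax normalizes the classification outputs: $e^{x}$ first maps every value to a positive number, and dividing by the sum then makes the probabilities of all classes add up to 1, so each one lies between 0 and 1. softmax does not change the size of its input, and each output value is the probability that the sample belongs to the corresponding class.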
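A minimal plain-Python sketch of softmax, written only to illustrate the two properties just described (positive outputs summing to 1, same length as the input); the scores are made up, and this is not PyTorch's implementation.

```python
import math

def softmax(x):
    """Map raw scores to probabilities: exp() makes each value positive,
    dividing by the total makes them sum to 1."""
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, -1.0, 0.5]  # hypothetical raw outputs for 3 classes
probs = softmax(scores)

print(probs)       # same length as scores, every value in (0, 1)
print(sum(probs))  # 1.0 up to floating-point rounding
```

Note that softmax keeps the ordering of the scores: the class with the largest raw score also gets the largest probability.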
Step 2: we want the predicted probability of the true class to be close to 100%. We pick out that value, $\frac{e^{x_{[\text{class}]}}}{\sum_j e^{x_j}}$, and want it to equal 100%.
Step 3: for a loss function, the closer the prediction is to the true value, the closer the loss should be to 0.
Taking the log of $\frac{e^{x_{[\text{class}]}}}{\sum_j e^{x_j}}$ and then negating it guarantees exactly this: because $\log 1 = 0$ and $\log p < 0$ for $0 < p < 1$, the quantity $\text{loss} = -\log\left(\frac{e^{x_{[\text{class}]}}}{\sum_j e^{x_j}}\right)$ is non-negative and gets closer to 0 as the probability of the true class approaches 100%.
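To see the shape of $-\log p$ concretely, here is a small plain-Python loop over a few made-up probabilities p for the true class; it is only an illustration of the argument in Step 3.

```python
import math

# Hypothetical probabilities that softmax might assign to the true class
ps = [0.01, 0.1, 0.5, 0.9, 0.99]
losses = [-math.log(p) for p in ps]

for p, loss in zip(ps, losses):
    print(f"p = {p:<5} -> -log(p) = {loss:.4f}")
```

The loss falls monotonically as p grows, is already close to 0 at p = 0.99, and becomes large when the model puts little probability on the true class.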
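II. Application
Suppose we have 4 images, i.e. batch_size = 4, and need to classify them into 5 classes, for example: bird, dog, cat, car, ship.
After the network's forward pass we obtain the prediction predict, of size [4, 5].
The ground-truth labels are label, of size [4].
Next, we compute the cross-entropy loss between the prediction predict and the ground truth label directly with nn.CrossEntropyLoss():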
```python
import torch
import torch.nn as nn

# -----------------------------------------
# Data: batch_size = 4; 5 classes in total
# label.size()  : torch.Size([4])     class indices, 0-based (0..4)
# predict.size(): torch.Size([4, 5])  raw, unnormalized scores
# -----------------------------------------
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
print(predict)
print(label)

# -----------------------------------------
# Compute the loss directly with nn.CrossEntropyLoss()
# -----------------------------------------
criterion = nn.CrossEntropyLoss()
loss = criterion(predict, label)
print(loss)
```

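nn.CrossEntropyLoss() can be decomposed into the following 3 steps; replacing it with these 3 operations gives exactly the same result:
- softmax: apply softmax to each image's class scores (along dim=1)
- log: take the log of that result
- NLL: nn.NLLLoss()(a, b) picks, for each sample, the entry of a at the index stored in b, negates it, and then (with the default reduction='mean') averages over the batch

Steps 1 and 2 can be merged into nn.LogSoftmax().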
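The merge of steps 1 and 2 can be checked with a short sketch: nn.LogSoftmax(dim=1) followed by nn.NLLLoss() should match nn.CrossEntropyLoss() on the same random inputs, assuming the default reduction='mean' for both losses. (PyTorch's documentation describes CrossEntropyLoss as combining LogSoftmax and NLLLoss.) The last lines spell out what NLLLoss does with plain indexing.

```python
import torch
import torch.nn as nn

torch.manual_seed(100)
predict = torch.rand(4, 5)          # same setup as the examples in this post
label = torch.tensor([4, 3, 3, 2])

# Steps 1 + 2 merged: log(softmax(x)) in a single module
log_softmax = nn.LogSoftmax(dim=1)
nll = nn.NLLLoss()                  # default reduction='mean'
loss_merged = nll(log_softmax(predict), label)

loss_ce = nn.CrossEntropyLoss()(predict, label)
print(loss_merged, loss_ce)
print(torch.allclose(loss_merged, loss_ce))

# What NLLLoss does: pick row i, column label[i], negate, average over the batch
manual = -log_softmax(predict)[torch.arange(4), label].mean()
print(manual)
```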
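```python
import torch
import torch.nn as nn

torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])

softmax = nn.Softmax(dim=1)  # step 1: softmax over the class dimension
nll = nn.NLLLoss()           # step 3: pick the entry at label, negate, average

temp1 = softmax(predict)     # probabilities; each row sums to 1
temp2 = torch.log(temp1)     # step 2: log-probabilities
output = nll(temp2, label)
print(output)  # tensor(1.5230)
```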
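Fully hand-written version (plain tensor operations, no nn modules):
```python
import torch

torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])

# softmax: exp, then divide by the row sum
temp1 = torch.exp(predict) / torch.sum(torch.exp(predict), dim=1, keepdim=True)
# log
temp2 = torch.log(temp1)
# nll: for each row, gather the log-probability at the true-class index
temp3 = torch.gather(temp2, dim=1, index=label.view(-1, 1))
temp4 = -temp3              # negate
output = torch.mean(temp4)  # average over the batch
print(output)  # tensor(1.5230)
```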
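One practical reason to prefer the merged LogSoftmax route: the hand-written version computes exp and then log separately, which is fine for the small values from torch.rand but can overflow for large scores (with float tensors this shows up as inf and then nan rather than an exception). A common remedy is the log-sum-exp trick. Below is a plain-Python sketch of that trick, not PyTorch's actual implementation; the helper names are hypothetical.

```python
import math

# Hypothetical helper: literal log(exp(x_i) / sum_j exp(x_j)), exp first, then log
def log_softmax_naive(x):
    exps = [math.exp(v) for v in x]
    total = sum(exps)
    return [math.log(e / total) for e in exps]

# Hypothetical helper: same quantity via log-sum-exp; subtracting max(x) first
# makes the largest exponent exp(0) = 1, so nothing overflows
def log_softmax_stable(x):
    m = max(x)
    lse = m + math.log(sum(math.exp(v - m) for v in x))
    return [v - lse for v in x]

small = [0.0, 1.0, 2.0]
big = [1000.0, 1001.0, 1002.0]  # small shifted by 1000; softmax is shift-invariant

print(log_softmax_naive(small))
print(log_softmax_stable(small))
print(log_softmax_stable(big))  # finite; same values as for small up to rounding

try:
    log_softmax_naive(big)      # math.exp(1000.0) is out of float range
except OverflowError as err:
    print("naive version failed:", err)
```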