当前位置:网站首页>【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
2022-07-01 09:03:00 【Enzo 想砸电脑】
一、意会的理解 交叉熵作为损失函数的意义
交叉熵损失多用于 多分类函数,下面我们通过拆解交叉熵的公式来理解其作为损失函数的意义
假设我们在做一个 n分类的问题,模型预测的输出结果是 [ x 1 , x 2 , x 3 , . . . . , x n ] [x_1, x_2, x_3, ...., x_n] [x1,x2,x3,....,xn]
然后,我们需要定义一个损失函数,然后通过反向传播调整模型的权重,这里的 损失函数我们就选择 交叉熵损失函数啦~
nn.CrossEntropyLoss() 的公式为:
l o s s ( x , c l a s s ) = − l o g ( e x [ c l a s s ] ∑ j e x j ) = − x [ c l a s s ] + l o g ( ∑ j e x j ) loss(x, class) = -log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) = -x_{[class]} + log(\sum_j e^{x_{j}}) loss(x,class)=−log(∑jexjex[class])=−x[class]+log(j∑exj)
- x 是预测结果,是一个向量,其元素个数是需要由模型保证的,保证和分类数一样多
- class 表示这个样本的实际标签,比如,样本实际属于分类2,那么class=2
x [ c l a s s ] x_{[class]} x[class] 就是 x 2 x_2 x2,就是取测试结果向量中的第二个元素,也就是取其真实分类对应的那个预测值
上面铺垫完了,接下来,我们要拆解公式,理解公式了
1、首先,交叉熵损失函数公式中包含了一个最基础的部分: s o f t m a x ( x i ) = e x i ∑ j e x j softmax(x_i) = \frac{e^{x_i}}{\sum_je^{x_{j}}} softmax(xi)=∑jexjexi
softmax 将分类的结果做了归一化: e x e^x ex 先将数据映射到(0, 1] 的区间,再使所有分类概率相加的总和等于1。 经过softmax处理后,size不会变,每个值的意义是样本被分到这个分类的概率。
2、我们想要使预测结果中,真实分类的那个值的概率接近 100%。 我们取出真实分类的那个值:
e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class],我们希望它的值是 100%
3、作为损失函数的意义是:当预测结果越接近真实值,损失函数的值越接近于0。
我们把 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 取log,再取反,就能保证当 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 越接近于100%, l o s s = − l o g ( e x [ c l a s s ] ∑ j e x j ) loss=-log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) loss=−log(∑jexjex[class]) 越接近0。
二、应用:
假设有4张图片,或者说batch_ size=4。我们需要把这4张图片分类到5个类别上去,比如说:鸟,狗,猫,汽车,船
经过网络计算后,我们得到了预测结果:predict,size为[4, 5]
其真实标签为 label,size为 [4]
接下来使用 nn.CrossEntropyLoss() 计算 预测结果predict 和 真实值label 的交叉熵损失,可以
import torch
import torch.nn as nn
# -----------------------------------------
# 定义数据: batch_size=4; 一共有5个分类
# label.size() : torch.Size([4])
# predict.size(): torch.Size([4, 5])
# -----------------------------------------
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
print(predict)
print(label)
# -----------------------------------------
# 直接调用函数 nn.CrossEntropyLoss() 计算 Loss
# -----------------------------------------
criterion = nn.CrossEntropyLoss()
loss = criterion(predict, label)
print(loss)

nn.CrossEntropyLoss() 可以拆解成如下3个步骤,或者说可以由如下3个操作替换,其运算结果一毛一样:
- softmax:对每张图片的分类结果做softmax, softmax详细介绍
- log:对上面的结果 取log
(步骤1 和 步骤2 可以合并为 nn.logSoftmax() ) - NLL:nn.NLLLoss(a, b) 的操作是从a 中取出b对应的那个值(b中存的是 index值),再去掉负号(取反),然后求和取均值
import torch
import torch.nn as nn
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
softmax = nn.Softmax(dim=1)
nll = nn.NLLLoss()
temp1 = softmax(predict)
temp2 = torch.log(temp1)
output = nll(temp2, label)
print(output) # tensor(1.5230)
纯手撸版本
import torch
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
# softmax
temp1 = torch.exp(predict) / torch.sum(torch.exp(predict), dim=1, keepdim=True)
# log
temp2 = torch.log(temp1)
# nll
temp3 = torch.gather(temp2, dim=1, index=label.view(-1, 1))
temp4 = -temp3
output = torch.mean(temp4)
print(output) # tensor(1.5230)
边栏推荐
- ARM v7的体系结构A、R、M区别,分别应用在什么领域?
- Principles of Microcomputer - Introduction
- Insert mathematical formula in MD document and mathematical formula in typora
- 序列化、监听、自定义注解
- Full mark standard for sports items in the high school entrance examination (Shenzhen, Anhui and Hubei)
- Shell script - definition, assignment and deletion of variables
- I use flask to write the website "one"
- Only in China! Alicloud container service enters the Forrester leader quadrant
- Shell脚本-echo命令 转义符
- 小鸟识别APP
猜你喜欢

如何解决固定资产管理和盘点的难题?

How to manage fixed assets well? Easy to point and move to provide intelligent solutions

Do you know how data is stored? (C integer and floating point)

足球篮球体育比赛比分直播平台源码/app开发建设项目

Ape anthropology topic 20 (the topic will be updated from time to time)

电视机尺寸与观看距离

如何一站式高效管理固定资产?

Insert mathematical formula in MD document and mathematical formula in typora

Phishing identification app

Nacos - Configuration Management
随机推荐
How can enterprises and developers take the lead in the outbreak of cloud native landing?
Understanding and implementation of AVL tree
Redis -- lattice connects to redis cluster
Shell script -select in loop
Shell脚本-for循环和for int循环
In the middle of the year, where should fixed asset management go?
钓鱼识别app
Shell script -if else statement
FAQ | FAQ for building applications for large screen devices
【ESP 保姆级教程】疯狂毕设篇 —— 案例:基于阿里云、小程序、Arduino的温湿度监控系统
Why is the Ltd independent station a Web3.0 website!
Leetcode daily question brushing record --540 A single element in an ordered array
Jetson nano installs tensorflow GPU and problem solving
FreeRTOS学习简易笔记
动态代理
Flink面试题
C language student information management system
Principles of Microcomputer - internal and external structure of microprocessor
I use flask to write the website "one"
美团2022年机试