当前位置:网站首页>【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
【pytorch】nn.CrossEntropyLoss() 与 nn.NLLLoss()
2022-07-01 09:03:00 【Enzo 想砸电脑】
一、意会的理解 交叉熵作为损失函数的意义
交叉熵损失多用于 多分类函数,下面我们通过拆解交叉熵的公式来理解其作为损失函数的意义
假设我们在做一个 n分类的问题,模型预测的输出结果是 [ x 1 , x 2 , x 3 , . . . . , x n ] [x_1, x_2, x_3, ...., x_n] [x1,x2,x3,....,xn]
然后,我们需要定义一个损失函数,然后通过反向传播调整模型的权重,这里的 损失函数我们就选择 交叉熵损失函数啦~
nn.CrossEntropyLoss() 的公式为:
l o s s ( x , c l a s s ) = − l o g ( e x [ c l a s s ] ∑ j e x j ) = − x [ c l a s s ] + l o g ( ∑ j e x j ) loss(x, class) = -log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) = -x_{[class]} + log(\sum_j e^{x_{j}}) loss(x,class)=−log(∑jexjex[class])=−x[class]+log(j∑exj)
- x 是预测结果,是一个向量,其元素个数是需要由模型保证的,保证和分类数一样多
- class 表示这个样本的实际标签,比如,样本实际属于分类2,那么class=2
x [ c l a s s ] x_{[class]} x[class] 就是 x 2 x_2 x2,就是取测试结果向量中的第二个元素,也就是取其真实分类对应的那个预测值
上面铺垫完了,接下来,我们要拆解公式,理解公式了
1、首先,交叉熵损失函数公式中包含了一个最基础的部分: s o f t m a x ( x i ) = e x i ∑ j e x j softmax(x_i) = \frac{e^{x_i}}{\sum_je^{x_{j}}} softmax(xi)=∑jexjexi
softmax 将分类的结果做了归一化: e x e^x ex 先将数据映射到(0, 1] 的区间,再使所有分类概率相加的总和等于1。 经过softmax处理后,size不会变,每个值的意义是样本被分到这个分类的概率。
2、我们想要使预测结果中,真实分类的那个值的概率接近 100%。 我们取出真实分类的那个值:
e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class],我们希望它的值是 100%
3、作为损失函数的意义是:当预测结果越接近真实值,损失函数的值越接近于0。
我们把 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 取log,再取反,就能保证当 e x [ c l a s s ] ∑ j e x j \frac{e^{x_{[class]}}}{\sum_je^{x_{j}}} ∑jexjex[class] 越接近于100%, l o s s = − l o g ( e x [ c l a s s ] ∑ j e x j ) loss=-log(\frac{e^{x_{[class]}}}{\sum_je^{x_{j}}}) loss=−log(∑jexjex[class]) 越接近0。
二、应用:
假设有4张图片,或者说batch_ size=4。我们需要把这4张图片分类到5个类别上去,比如说:鸟,狗,猫,汽车,船
经过网络计算后,我们得到了预测结果:predict,size为[4, 5]
其真实标签为 label,size为 [4]
接下来使用 nn.CrossEntropyLoss() 计算 预测结果predict 和 真实值label 的交叉熵损失,可以
import torch
import torch.nn as nn
# -----------------------------------------
# 定义数据: batch_size=4; 一共有5个分类
# label.size() : torch.Size([4])
# predict.size(): torch.Size([4, 5])
# -----------------------------------------
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
print(predict)
print(label)
# -----------------------------------------
# 直接调用函数 nn.CrossEntropyLoss() 计算 Loss
# -----------------------------------------
criterion = nn.CrossEntropyLoss()
loss = criterion(predict, label)
print(loss)

nn.CrossEntropyLoss() 可以拆解成如下3个步骤,或者说可以由如下3个操作替换,其运算结果一毛一样:
- softmax:对每张图片的分类结果做softmax, softmax详细介绍
- log:对上面的结果 取log
(步骤1 和 步骤2 可以合并为 nn.logSoftmax() ) - NLL:nn.NLLLoss(a, b) 的操作是从a 中取出b对应的那个值(b中存的是 index值),再去掉负号(取反),然后求和取均值
import torch
import torch.nn as nn
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
softmax = nn.Softmax(dim=1)
nll = nn.NLLLoss()
temp1 = softmax(predict)
temp2 = torch.log(temp1)
output = nll(temp2, label)
print(output) # tensor(1.5230)
纯手撸版本
import torch
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
# softmax
temp1 = torch.exp(predict) / torch.sum(torch.exp(predict), dim=1, keepdim=True)
# log
temp2 = torch.log(temp1)
# nll
temp3 = torch.gather(temp2, dim=1, index=label.view(-1, 1))
temp4 = -temp3
output = torch.mean(temp4)
print(output) # tensor(1.5230)
边栏推荐
- How to manage fixed assets well? Easy to point and move to provide intelligent solutions
- Nacos - 服务发现
- Nacos - Configuration Management
- 日常办公耗材管理解决方案
- Shell script case in statement
- Public network cluster intercom +gps visual tracking | help the logistics industry with intelligent management and scheduling
- Shell脚本-位置参数(命令行参数)
- 记一次redis超时
- 嵌入式工程师面试题3-硬件
- 大型工厂设备管理痛点和解决方案
猜你喜欢

集团公司固定资产管理的痛点和解决方案
![[interview brush 101] linked list](/img/52/d159bc66c0dbc44c1282a96cf6b2fd.png)
[interview brush 101] linked list

AVL树的理解和实现

Understanding and implementation of AVL tree

FAQ | FAQ for building applications for large screen devices

I use flask to write the website "one"

Centos7 shell script one click installation of JDK, Mongo, Kafka, FTP, PostgreSQL, PostGIS, pgrouting

Redis——Lettuce连接redis集群

An overview of the design of royalties and service fees of mainstream NFT market platforms

Jetson Nano 安装TensorFlow GPU及问题解决
随机推荐
IT 技术电子书 收藏
Shell script -for loop and for int loop
安装Oracle EE
Shell script -if else statement
Shell script - definition, assignment and deletion of variables
Shell script -while loop explanation
Foundation: 3 Opencv getting started images and videos
Shell script -read command: read data entered from the keyboard
Shell script - positional parameters (command line parameters)
I use flask to write the website "one"
Which method is good for the management of fixed assets of small and medium-sized enterprises?
DataBinding源码分析
FAQ | FAQ for building applications for large screen devices
Principles of Microcomputer - internal and external structure of microprocessor
中断与其他函数共享变量、临界资源的保护
Shell script -select in loop
Serialization, listening, custom annotation
Shell脚本-变量的定义、赋值和删除
Principles of Microcomputer - Introduction
序列化、监听、自定义注解