【pytorch】nn. Crossentropyloss() and nn NLLLoss()
2022-07-01 09:08:00 【Enzo tried to smash the computer】
1. The meaning of cross entropy as a loss function

Cross-entropy loss is commonly used for multi-class classification. Let's understand what it means as a loss function by breaking its formula apart.

Suppose we are solving an n-class classification problem, and the model's prediction output is $[x_1, x_2, x_3, \dots, x_n]$.

We then need to define a loss function so that backpropagation can adjust the model's weights; here we choose the cross-entropy loss function.
The formula for nn.CrossEntropyLoss() is:

$$loss(x, class) = -\log\left(\frac{e^{x_{class}}}{\sum_j e^{x_j}}\right) = -x_{class} + \log\left(\sum_j e^{x_j}\right)$$
- x is the prediction, a vector; the model must output as many elements as there are classes.
- class is the sample's ground-truth label; for example, if the sample actually belongs to class 2, then class = 2.
- $x_{class}$ is then $x_2$: the element of the prediction vector at the index of the true class.
That's the formula. Next, let's break it apart to understand it.

1. First, the cross-entropy loss formula contains a basic building block: $softmax(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$

softmax normalizes the classification scores: $e^x$ first maps every score to a positive number, and dividing by the sum makes all the class probabilities add up to 1. After softmax the size is unchanged, and each value is the probability that the sample belongs to that class.
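A quick sanity check of these two properties (a minimal sketch using torch.softmax; the input shape here is arbitrary):

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 3)        # 2 samples, 3 classes
p = torch.softmax(x, dim=1)  # normalize each row into probabilities

print(p.shape)       # same size as the input: torch.Size([2, 3])
print(p.sum(dim=1))  # each row sums to 1
```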
2. We want the prediction to assign a probability close to 100% to the true class. Take the true-class term: $\frac{e^{x_{class}}}{\sum_j e^{x_j}}$; we want its value to approach 100%.

3. What makes something a loss function: the closer the prediction is to the true value, the closer the loss should be to 0.

If we take the log of $\frac{e^{x_{class}}}{\sum_j e^{x_j}}$ and then negate it, we guarantee that the closer $\frac{e^{x_{class}}}{\sum_j e^{x_j}}$ gets to 100%, the closer $loss = -\log\left(\frac{e^{x_{class}}}{\sum_j e^{x_j}}\right)$ gets to 0.
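To see this behavior numerically, here is a tiny illustration (plain Python, with arbitrarily chosen probabilities) of how $-\log(p)$ shrinks toward 0 as the true-class probability approaches 100%:

```python
import math

# The closer the true-class probability is to 1, the closer -log(p) is to 0
probs = [0.1, 0.5, 0.9, 0.99]
losses = [-math.log(p) for p in probs]
for p, l in zip(probs, losses):
    print(f"p={p:.2f}  loss={l:.4f}")
# p=0.10 gives loss 2.3026; p=0.99 gives loss 0.0101
```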
2. Application

Suppose there are 4 pictures, i.e. batch_size = 4, and we need to classify them into 5 categories, for example: bird, dog, cat, automobile, ship.

After the forward pass we get the prediction predict, with size [4, 5].

The ground-truth label is label, with size [4].

We can then use nn.CrossEntropyLoss() to compute the cross-entropy loss between predict and label:
import torch
import torch.nn as nn
# -----------------------------------------
# Defining data : batch_size=4; Altogether 5 A classification
# label.size() : torch.Size([4])
# predict.size(): torch.Size([4, 5])
# -----------------------------------------
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
print(predict)
print(label)
# -----------------------------------------
# Call function directly nn.CrossEntropyLoss() Calculation Loss
# -----------------------------------------
criterion = nn.CrossEntropyLoss()
loss = criterion(predict, label)
print(loss)  # tensor(1.5230)

nn.CrossEntropyLoss() can be decomposed into the following 3 steps; replacing it with these 3 operations gives the same result:

- softmax: apply softmax to each picture's classification scores
- log: take the log of the result above
- NLL: nn.NLLLoss(a, b) picks out of a the values at the indices stored in b, negates them, then sums and takes the mean

(Steps 1 and 2 can be merged into nn.LogSoftmax().)
import torch
import torch.nn as nn
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
softmax = nn.Softmax(dim=1)
nll = nn.NLLLoss()
temp1 = softmax(predict)
temp2 = torch.log(temp1)
output = nll(temp2, label)
print(output) # tensor(1.5230)
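As noted above, steps 1 and 2 can be merged into nn.LogSoftmax(), which is also more numerically stable than computing softmax and log separately; a sketch on the same data:

```python
import torch
import torch.nn as nn

torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])

# LogSoftmax fuses steps 1 (softmax) and 2 (log) into one operation
log_softmax = nn.LogSoftmax(dim=1)
nll = nn.NLLLoss()
output = nll(log_softmax(predict), label)
print(output)  # same value as nn.CrossEntropyLoss(): tensor(1.5230)
```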
Fully hand-written version:
import torch
torch.manual_seed(100)
predict = torch.rand(4, 5)
label = torch.tensor([4, 3, 3, 2])
# softmax
temp1 = torch.exp(predict) / torch.sum(torch.exp(predict), dim=1, keepdim=True)
# log
temp2 = torch.log(temp1)
# nll
temp3 = torch.gather(temp2, dim=1, index=label.view(-1, 1))
temp4 = -temp3
output = torch.mean(temp4)
print(output) # tensor(1.5230)