cross entropy loss = log softmax + nll loss
2022-06-26 05:32:00 【wujpbb7】
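The identity in the title can be written out as a short derivation. This is a sketch under stated assumptions: the symbols are introduced here for illustration and do not appear in the original post (z is the N×C matrix of logits, y the vector of integer class labels, ℓ the matrix of log-probabilities), and the mean over samples assumes the default `reduction='mean'` of `torch.nn.CrossEntropyLoss` and `torch.nn.NLLLoss`.

```latex
\begin{aligned}
\operatorname{log\_softmax}(z)_{i,c} &= z_{i,c} - \log \sum_{k=1}^{C} e^{z_{i,k}} \\
\operatorname{NLL}(\ell, y) &= -\frac{1}{N} \sum_{i=1}^{N} \ell_{i,y_i} \\
\operatorname{CE}(z, y) &= -\frac{1}{N} \sum_{i=1}^{N} \left( z_{i,y_i} - \log \sum_{k=1}^{C} e^{z_{i,k}} \right)
 = \operatorname{NLL}\bigl(\operatorname{log\_softmax}(z),\, y\bigr)
\end{aligned}
```

With a temperature T, the same formulas apply with z replaced by z / T.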
The code is as follows:
import torch
logits = torch.randn(3,4,requires_grad=True)
labels = torch.LongTensor([1,0,2])
print('logits={}, labels={}'.format(logits,labels))
# Compute the cross entropy loss directly
def calc_ce_loss1(logits, labels):
    ce_loss = torch.nn.CrossEntropyLoss()
    loss = ce_loss(logits, labels)
    return loss

# Compute the cross entropy loss in two steps
# (cross entropy loss = log softmax + nll loss)
def calc_ce_loss2(logits, labels):
    log_softmax = torch.nn.LogSoftmax(dim=1)  # normalize over the class dimension
    nll_loss = torch.nn.NLLLoss()
    logits_ls = log_softmax(logits)
    loss = nll_loss(logits_ls, labels)
    return loss

loss1 = calc_ce_loss1(logits, labels)
print('loss1={}'.format(loss1))
loss2 = calc_ce_loss2(logits, labels)
print('loss2={}'.format(loss2))
# Add a temperature: divide the logits by it before computing the loss
temperature = 0.05
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))
temperature = 2
logits_t = logits / temperature
loss1 = calc_ce_loss1(logits_t, labels)
print('t={}, loss1={}'.format(temperature, loss1))
loss2 = calc_ce_loss2(logits_t, labels)
print('t={}, loss2={}'.format(temperature, loss2))

The output is as follows:
logits=tensor([[-0.7441, -2.3802, -0.1708, 0.5020],
[ 0.3381, -0.3981, 2.2979, 0.6773],
[-0.5372, -0.4489, -0.0680, 0.4889]], requires_grad=True), labels=tensor([1, 0, 2])
loss1=2.399930000305176
loss2=2.399930000305176
t=0.05, loss1=35.99229431152344
t=0.05, loss2=35.99229431152344
t=2, loss1=1.8117588758468628
t=2, loss2=1.8117588758468628
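
As a cross-check that does not depend on PyTorch, the same decomposition can be sketched with the standard library only. The logits below are copied from the printed output above, so they are rounded to four decimals; the helper names (`log_softmax`, `nll_loss`, `cross_entropy`) are illustrative and not part of the original post. Because of the rounding, the results should agree with the printed losses only to about three decimal places.

```python
import math

# Logits and labels copied from the printed output above (rounded to 4 decimals).
logits = [
    [-0.7441, -2.3802, -0.1708, 0.5020],
    [0.3381, -0.3981, 2.2979, 0.6773],
    [-0.5372, -0.4489, -0.0680, 0.4889],
]
labels = [1, 0, 2]

def log_softmax(row):
    # Subtract the row maximum first so that exp() cannot overflow.
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum_exp for v in row]

def nll_loss(log_probs, labels):
    # Mean of the negated log-probability of the true class of each sample.
    return -sum(lp[y] for lp, y in zip(log_probs, labels)) / len(labels)

def cross_entropy(logits, labels, temperature=1.0):
    # cross entropy loss = log softmax + nll loss, applied to logits / temperature
    scaled = [[v / temperature for v in row] for row in logits]
    return nll_loss([log_softmax(row) for row in scaled], labels)

for t in (1.0, 0.05, 2.0):
    print('t={}, loss={}'.format(t, cross_entropy(logits, labels, t)))
```

A temperature below 1 sharpens the softmax distribution and a temperature above 1 flattens it. In this run every true-class logit is below its row maximum, so sharpening drives the loss up (t=0.05) and flattening brings it down (t=2).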