A summary of commonly used loss functions in PyTorch
2022-06-30 05:08:00 【武乐乐~】
Preface
This article covers how to use the loss function APIs that come up most often in PyTorch.
1. Classification losses
1.1. nn.BCELoss()

nn.BCELoss() computes the binary cross-entropy loss. The default initialization is usually sufficient, i.e. reduction='mean', which returns the loss averaged over all elements. In the forward pass, input and target must have the same shape; here the target is a one-hot encoding, and the input must be passed through sigmoid beforehand.
import torch
import torch.nn as nn
import torch.nn.functional as F

# Binary cross-entropy loss; handles binary classification only
# Assume a binary classification problem with batch size = 2
input = torch.Tensor([[-1, 1], [1, 2]])  # input: [2, 2]
input = input.sigmoid()                  # BCELoss expects probabilities, so apply sigmoid first
# Convert the labels to one-hot
target = torch.Tensor([0, 1])            # shape: [2]
onehot_target = torch.eye(2)[target.long(), :]
Loss = nn.BCELoss()                      # default initialization (reduction='mean')
loss1 = Loss(input, onehot_target)
loss2 = F.binary_cross_entropy(input, onehot_target)
print(loss1, loss2)  # both 1.0167
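As a quick sanity check, the 1.0167 above can be reproduced by evaluating the binary cross-entropy formula -(t*log(p) + (1-t)*log(1-p)) by hand and averaging over every element. A minimal standalone sketch using the same numbers (the names p and t below are just for illustration):
import torch

# Same logits and one-hot targets as above
p = torch.tensor([[-1., 1.], [1., 2.]]).sigmoid()   # probabilities after sigmoid
t = torch.tensor([[1., 0.], [0., 1.]])              # one-hot for labels [0, 1]
# BCE formula: -(t*log(p) + (1-t)*log(1-p)), averaged over all elements
manual = -(t * torch.log(p) + (1 - t) * torch.log(1 - p)).mean()
print(manual)  # tensor(1.0167), matching nn.BCELoss()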
1.2. nn.BCEWithLogitsLoss()
This loss function has the sigmoid step built in, so the input can be the raw network output; there is no need to apply sigmoid by hand.
import torch
import torch.nn as nn
import torch.nn.functional as F

# Binary cross-entropy loss with built-in sigmoid
# Assume a binary classification problem with batch size = 2
input = torch.Tensor([[-1, 1], [1, 2]])  # raw logits, no sigmoid applied: [2, 2]
# Convert the labels to one-hot
target = torch.Tensor([0, 1])            # shape: [2]
onehot_target = torch.eye(2)[target.long(), :]
Loss = nn.BCEWithLogitsLoss()            # default initialization
loss1 = Loss(input, onehot_target)
loss2 = F.binary_cross_entropy_with_logits(input, onehot_target)
print(loss1, loss2)  # both 1.0167
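To see that the two APIs really agree, here is a minimal sketch (same made-up numbers as above) comparing nn.BCEWithLogitsLoss() on raw logits against nn.BCELoss() on manually sigmoid-ed probabilities:
import torch
import torch.nn as nn

logits = torch.tensor([[-1., 1.], [1., 2.]])
target = torch.tensor([[1., 0.], [0., 1.]])               # one-hot for labels [0, 1]
loss_with_logits = nn.BCEWithLogitsLoss()(logits, target)  # sigmoid applied internally
loss_manual_sig = nn.BCELoss()(logits.sigmoid(), target)   # sigmoid applied by hand
print(loss_with_logits, loss_manual_sig)                   # both 1.0167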
1.3. Multi-class cross-entropy loss
1) For multi-class classification, the formula is:
L = -(1/N) * Σ_{i=1..N} Σ_{c=1..K} y_ic * log(p_ic)
where N is the total number of samples, K is the number of classes, y_ic is 1 when sample i belongs to class c (the one-hot label) and 0 otherwise, and p_ic is the predicted probability that sample i belongs to class c. To make this concrete, let's implement the formula by hand: assume there are 3 samples (N=3), each row of p has already passed through softmax so its probabilities sum to 1, and the labels are converted to one-hot.
Ignore the outer sum over samples at first and compute the inner sums L1, L2, L3 for each sample; then apply the outer sum and divide by N.
import torch

p = torch.Tensor([[0.2, 0.3, 0.5], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]])  # each row already sums to 1
label = torch.Tensor([0, 1, 2])
onehot = torch.eye(3)[label.long(), :]
# Inner sum: per-sample cross entropy
log_p = torch.log(p)              # take the log
loss = torch.sum(onehot * log_p)  # element-wise product, then sum
# Outer sum: average over the N samples
loss = -loss / p.shape[0]
print(loss)  # 1.4229
2) To simplify this workflow (the label must be one-hot encoded and multiplied in by hand), PyTorch wraps it as nn.NLLLoss(), which reduces the code above to:
import torch
import torch.nn as nn

Loss = nn.NLLLoss()
p = torch.Tensor([[0.2, 0.3, 0.5], [0.1, 0.7, 0.2], [0.4, 0.5, 0.1]])
label = torch.Tensor([0, 1, 2]).long()
# NLLLoss takes log-probabilities plus integer class labels,
# so the one-hot encoding and the element-wise multiply/sum are no longer needed
log_p = torch.log(p)   # take the log
loss = Loss(log_p, label)
print(loss)  # 1.4229, same as the manual computation
3) This is still not as simple as it could be, because p still needs the softmax + log steps. PyTorch therefore wraps things one level further into nn.CrossEntropyLoss(). Ignoring its parameters for now, just use it directly:
import torch
import torch.nn as nn
import torch.nn.functional as F

p = torch.randn(4, 3)                      # raw network output, no softmax applied
label = torch.Tensor([0, 1, 2, 0]).long()
# First compute it the long way: log_softmax + NLLLoss
log_p = F.log_softmax(p, dim=1)
Loss = nn.NLLLoss()
loss1 = Loss(log_p, label)
# Then compute it with CrossEntropyLoss on the raw output
Loss = nn.CrossEntropyLoss()
loss2 = Loss(p, label)
print(loss1, loss2)  # the two results are identical
To summarize: nn.CrossEntropyLoss() bundles log, softmax, and the one-hot handling together. The prediction only needs to be an [N, C] tensor of raw, un-softmaxed scores, and the label only needs to be an [N] tensor of ordinary class indices; pass them to the API and it returns the cross-entropy loss.
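To tie the three forms together, here is a minimal sketch (random logits, arbitrary seed) showing that the manual one-hot computation, NLLLoss on log-probabilities, and CrossEntropyLoss on raw scores all produce the same value:
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)          # raw network output, [N, C]
label = torch.tensor([0, 1, 2, 0])  # class indices, [N]

loss_ce = F.cross_entropy(logits, label)                          # all-in-one API
loss_nll = F.nll_loss(F.log_softmax(logits, dim=1), label)        # log_softmax + NLLLoss
onehot = F.one_hot(label, num_classes=3).float()
loss_man = -(onehot * F.log_softmax(logits, dim=1)).sum() / len(label)  # manual one-hot form
print(loss_ce, loss_nll, loss_man)  # all three are identical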
4) One extra parameter is worth noting: ignore_index, which excludes a given class label from the loss. For example, setting it to 0 drops the loss terms whose target label is 0 and averages the loss over the remaining, non-ignored targets.
import torch
import torch.nn as nn
import torch.nn.functional as F

p = torch.Tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6], [0.1, 0.2, 0.3]])  # [3, 3]
label = torch.Tensor([0, 1, 1]).long()                                  # [3]
# Now drop the loss of targets labeled 0
Loss = nn.CrossEntropyLoss(ignore_index=0)
loss3 = Loss(p, label)
print(loss3)  # 1.1019
print('Verifying ignore_index manually')
p = F.softmax(p, dim=1)                 # apply softmax to p
onehot = torch.eye(3)[label.long(), :]
# Per-sample cross entropy
log_p = torch.log(p)
v = onehot * log_p
loss = torch.sum(v[1:])                 # drop the sample whose label is 0 (row 0)
# Outer average
loss = -loss / 2                        # 2 non-ignored samples, so divide by 2
print(loss)  # 1.1019, matches the ignore_index result
1.4. Focal loss
Having covered cross-entropy, the widely used focal loss deserves a mention. Its formula is:
FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t), where p_t = p when the target is 1 and p_t = 1 - p when the target is 0
From the formula, implementing focal loss starts with CE(p_t) = -log(p_t), i.e. the plain binary cross-entropy, which can be obtained directly from nn.BCEWithLogitsLoss() since it has sigmoid built in; the target again needs to be one-hot encoded.
With CE in hand, the only remaining piece is p_t itself; note that this step requires applying sigmoid manually. In the paper, alpha_t is built the same way as p_t: alpha_t = alpha when the target is 1 and 1 - alpha when it is 0.
Here is a classic focal loss implementation:
import torch
from torch import nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduce='sum'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduce = reduce

    def forward(self, classifications, targets):
        # classifications: [N, K] raw scores, NOT passed through sigmoid
        # targets: [N, K] one-hot encoding
        alpha = self.alpha
        gamma = self.gamma
        classifications = classifications.view(-1)  # flatten the un-sigmoided scores
        targets = targets.view(-1)                  # flatten the one-hot targets
        # ce_loss corresponds to -log(pt) in the formula, i.e. the plain binary cross entropy;
        # binary_cross_entropy_with_logits expects un-sigmoided inputs
        ce_loss = F.binary_cross_entropy_with_logits(classifications, targets.float(), reduction="none")
        # focal loss
        p = torch.sigmoid(classifications)           # apply sigmoid
        p_t = p * targets + (1 - p) * (1 - targets)  # compute pt
        loss = ce_loss * ((1 - p_t) ** gamma)        # -log(pt) * (1 - pt) ** gamma
        if alpha >= 0:
            # alpha_t weights the loss, as in the formula
            alpha_t = alpha * targets + (1 - alpha) * (1 - targets)  # same construction as pt
            loss = alpha_t * loss                    # final focal loss
        if self.reduce == 'sum':
            loss = loss.sum()
        elif self.reduce == 'mean':
            loss = loss.mean()
        else:
            raise ValueError('reduce type is wrong!')
        return loss
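A quick usage sketch for the class above (the logits and one-hot targets below are made-up values, and the imports and FocalLoss definition from the block above are reused). With gamma=0 and a negative alpha, the focal term and the alpha weighting both switch off, so the result should fall back to the plain summed BCE-with-logits loss, which makes an easy sanity check:
logits = torch.tensor([[2.0, -1.0, 0.5], [0.3, 1.5, -2.0]])  # [N, K] raw scores
targets = torch.tensor([[1., 0., 0.], [0., 1., 0.]])         # [N, K] one-hot labels

focal = FocalLoss(alpha=0.25, gamma=2.0, reduce='sum')
print(focal(logits, targets))                                # the actual focal loss

plain = FocalLoss(alpha=-1, gamma=0.0, reduce='sum')         # focal/alpha terms disabled
bce = F.binary_cross_entropy_with_logits(logits, targets, reduction='sum')
print(plain(logits, targets), bce)                           # the two values match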
2. Regression losses