当前位置:网站首页>【深度学习】一文看尽Pytorch之十九种损失函数
【深度学习】一文看尽Pytorch之十九种损失函数
2022-07-04 18:40:00 【风度78】
仅作学术分享,不代表本公众号立场,侵权联系删除
转载于:作者:mingo_敏
原文链接:https://blog.csdn.net/shanglianlm/article/details/85019768
导读
本文汇总了十九种损失函数,介绍了其数学公式和代码实现方式,希望大家可以熟练掌握。
01
基本用法
criterion = LossCriterion() #构造函数有自己的参数
loss = criterion(x, y) #调用标准时也有参数
02
损失函数
2-1 L1范数损失 L1Loss
计算 output 和 target 之差的绝对值。
torch.nn.L1Loss(reduction='mean')
参数:
reduction-三个值,none: 不使用约简;mean:返回loss和的平均值;sum:返回loss的和。默认:mean。
2-2 均方误差损失 MSELoss
计算 output 和 target 之差的均方差。
torch.nn.MSELoss(reduction='mean')
参数:
reduction-三个值,none: 不使用约简;mean:返回loss和的平均值;sum:返回loss的和。默认:mean。
2-3 交叉熵损失 CrossEntropyLoss
当训练有 C 个类别的分类问题时很有效. 可选参数 weight 必须是一个1维 Tensor, 权重将被分配给各个类别. 对于不平衡的训练集非常有效。
在多分类任务中,经常采用 softmax 激活函数+交叉熵损失函数,因为交叉熵描述了两个概率分布的差异,然而神经网络输出的是向量,并不是概率分布的形式。所以需要 softmax激活函数将一个向量进行“归一化”成概率分布的形式,再采用交叉熵损失函数计算 loss。
torch.nn.CrossEntropyLoss(weight=None, ignore_index=-100, reduction='mean')
参数:
weight (Tensor, optional) – 自定义的每个类别的权重. 必须是一个长度为 C 的 Tensor
ignore_index (int, optional) – 设置一个目标值, 该目标值会被忽略, 从而不会影响到 输入的梯度。
reduction-三个值,none: 不使用约简;mean:返回loss和的平均值;sum:返回loss的和。默认:mean。
2-4 KL 散度损失 KLDivLoss
计算 input 和 target 之间的 KL 散度。KL 散度可用于衡量不同的连续分布之间的距离, 在连续的输出分布的空间上(离散采样)上进行直接回归时很有效.
torch.nn.KLDivLoss(reduction='mean')
参数:
reduction-三个值,none: 不使用约简;mean:返回loss和的平均值;sum:返回loss的和。默认:mean。
2-5 二进制交叉熵损失 BCELoss
二分类任务时的交叉熵计算函数。用于测量重构的误差, 例如自动编码机. 注意目标的值 t[i] 的范围为0到1之间.
torch.nn.BCELoss(weight=None, reduction='mean')
参数:
weight (Tensor, optional) – 自定义的每个 batch 元素的 loss 的权重. 必须是一个长度为 “nbatch” 的 的 Tensor
pos_weight(Tensor, optional) – 自定义的每个正样本的 loss 的权重. 必须是一个长度 为 “classes” 的 Tensor
2-6 BCEWithLogitsLoss
BCEWithLogitsLoss损失函数把 Sigmoid 层集成到了 BCELoss 类中. 该版比用一个简单的 Sigmoid 层和 BCELoss 在数值上更稳定, 因为把这两个操作合并为一个层之后, 可以利用 log-sum-exp 的 技巧来实现数值稳定.
torch.nn.BCEWithLogitsLoss(weight=None, reduction='mean', pos_weight=None)
参数:
weight (Tensor, optional) – 自定义的每个 batch 元素的 loss 的权重. 必须是一个长度 为 “nbatch” 的 Tensor
pos_weight(Tensor, optional) – 自定义的每个正样本的 loss 的权重. 必须是一个长度 为 “classes” 的 Tensor
2-7 MarginRankingLoss
torch.nn.MarginRankingLoss(margin=0.0, reduction='mean')
对于 mini-batch(小批量) 中每个实例的损失函数如下:
参数:
margin:默认值0
2-8 HingeEmbeddingLoss
torch.nn.HingeEmbeddingLoss(margin=1.0, reduction='mean')
对于 mini-batch(小批量) 中每个实例的损失函数如下:
参数:
margin:默认值1
2-9 多标签分类损失 MultiLabelMarginLoss
torch.nn.MultiLabelMarginLoss(reduction='mean')
对于mini-batch(小批量) 中的每个样本按如下公式计算损失:
2-10 平滑版L1损失 SmoothL1Loss
也被称为 Huber 损失函数。
torch.nn.SmoothL1Loss(reduction='mean')
其中
2-11 2分类的logistic损失 SoftMarginLoss
torch.nn.SoftMarginLoss(reduction='mean')
2-12 多标签 one-versus-all 损失 MultiLabelSoftMarginLoss
torch.nn.MultiLabelSoftMarginLoss(weight=None, reduction='mean')
2-13 cosine 损失 CosineEmbeddingLoss
torch.nn.CosineEmbeddingLoss(margin=0.0, reduction='mean')
参数:
margin:默认值0
2-14 多类别分类的hinge损失 MultiMarginLoss
torch.nn.MultiMarginLoss(p=1, margin=1.0, weight=None, reduction='mean')
参数:
p=1或者2 默认值:1
margin:默认值1
2-15 三元组损失 TripletMarginLoss
torch.nn.TripletMarginLoss(margin=1.0, p=2.0, eps=1e-06, swap=False, reduction='mean')
其中:
2-16 连接时序分类损失 CTCLoss
CTC连接时序分类损失,可以对没有对齐的数据进行自动对齐,主要用在没有事先对齐的序列化数据训练上。比如语音识别、ocr识别等等。
torch.nn.CTCLoss(blank=0, reduction='mean')
参数:
reduction-三个值,none: 不使用约简;mean:返回loss和的平均值;sum:返回loss的和。默认:mean。
2-17 负对数似然损失 NLLLoss
负对数似然损失. 用于训练 C 个类别的分类问题.
torch.nn.NLLLoss(weight=None, ignore_index=-100, reduction='mean')
参数:
weight (Tensor, optional) – 自定义的每个类别的权重. 必须是一个长度为 C 的 Tensor
ignore_index (int, optional) – 设置一个目标值, 该目标值会被忽略, 从而不会影响到 输入的梯度.
2-18 NLLLoss2d
对于图片输入的负对数似然损失. 它计算每个像素的负对数似然损失.
torch.nn.NLLLoss2d(weight=None, ignore_index=-100, reduction='mean')
参数:
weight (Tensor, optional) – 自定义的每个类别的权重. 必须是一个长度为 C 的 Tensor
reduction-三个值,none: 不使用约简;mean:返回loss和的平均值;sum:返回loss的和。默认:mean。
2-19 PoissonNLLLoss
目标值为泊松分布的负对数似然损失
torch.nn.PoissonNLLLoss(log_input=True, full=False, eps=1e-08, reduction='mean')
参数:
log_input (bool, optional) – 如果设置为 True , loss 将会按照公 式 exp(input) - target * input 来计算, 如果设置为 False , loss 将会按照 input - target * log(input+eps) 计算.
full (bool, optional) – 是否计算全部的 loss, i. e. 加上 Stirling 近似项 target * log(target) - target + 0.5 * log(2 * pi * target).
eps (float, optional) – 默认值: 1e-8
参考资料
http://www.voidcn.com/article/p-rtzqgqkz-bpg.html
往期精彩回顾
适合初学者入门人工智能的路线及资料下载(图文+视频)机器学习入门系列下载中国大学慕课《机器学习》(黄海广主讲)机器学习及深度学习笔记等资料打印《统计学习方法》的代码复现专辑机器学习交流qq群955171419,加入微信群请扫码
边栏推荐
- 应用实践 | 蜀海供应链基于 Apache Doris 的数据中台建设
- QT writing the Internet of things management platform 38- multiple database support
- 黑马程序员-软件测试--08阶段2-linux和数据库-23-30-进程端口相关,修改文件权限,端口号信息的获取,程序和进程相关操作,linux命令案例
- 1007 maximum subsequence sum (25 points) (PAT class a)
- 1005 spell it right (20 points) (pat a)
- 1008 elevator (20 points) (PAT class a)
- [problem] Druid reports exception SQL injection violation, part always true condition not allow solution
- Kotlin condition control
- C语言-入门-基础-语法-流程控制(七)
- Chrome开发工具:VMxxx文件是什么鬼
猜你喜欢
YOLOv5s-ShuffleNetV2
C server log module
abc229 总结(区间最长连续字符 图的联通分量计数)
关于联邦学习和激励的相关概念(1)
做社交媒体营销应该注意些什么?Shopline卖家的成功秘笈在这里!
Application practice | Shuhai supply chain construction of data center based on Apache Doris
Crystal optoelectronics: ar-hud products of Chang'an dark blue sl03 are supplied by the company
node强缓存和协商缓存实战示例
Small hair cat Internet of things platform construction and application model
[problem] Druid reports exception SQL injection violation, part always true condition not allow solution
随机推荐
In the first month of its launch, the tourist praise rate of this campsite was as high as 99.9%! How did he do it?
水晶光电:长安深蓝SL03的AR-HUD产品由公司供应
[QNX hypervisor 2.2 user manual]6.3.1 factory page and control page
C # better operation mongodb database
Template_ Judging prime_ Square root / six prime method
Actual combat simulation │ JWT login authentication
HDU 1097 A hard puzzle
Niuke Xiaobai month race 7 e applese's super ability
Delete the characters with the least number of occurrences in the string [JS, map sorting, regular]
Lingyun going to sea | Wenhua online & Huawei cloud: creating a new solution for smart teaching in Africa
[Beijing Xunwei] i.mx6ull development board porting Debian file system
Qt编写物联网管理平台38-多种数据库支持
Kotlin classes and objects
Niuke Xiaobai month race 7 who is the divine Archer
[graduation season] green ant new fermented grains wine, red mud small stove. If it snows late, can you drink a cup?
15million employees are easy to manage, and the cloud native database gaussdb makes HR office more efficient
Wireshark network packet capture
C语言-入门-基础-语法-流程控制(七)
六石编程学:关于代码,有六个得意
上线首月,这家露营地游客好评率高达99.9%!他是怎么做到的?