当前位置:网站首页>Pytorch学习(二)
Pytorch学习(二)
2022-06-30 02:23:00 【马少爷】
一、Pycharm查看函数参数和用法
1、 使用右键查看函数信息
1.1. 详细参数
鼠标放置在函数上:右键—>Go To—>Declaration or Usages 便会跳转到函数的源码。也可以使用快捷键 Ctrl+B
1.2. 函数使用情况
鼠标放置在函数上:右键—>Find Usages 便会在控制台输出该函数的使用情况。也可以使用快捷键 Alt+F7

2. 使用Ctrl查看函数信息
2.1. 详细参数
按住Ctrl将鼠标放在需要查看的函数上,便会出现该函数所需参数等简略信息。如需查看详细参数鼠标点击函数,会直接跳转到函数的源码。
二、nn.Dropout
dropout是Hinton老爷子提出来的一个用于训练的trick。在pytorch中,除了原始的用法以外,还有数据增强的用法(后文提到)。
首先要知道,dropout是专门用于训练的。在推理阶段,则需要把dropout关掉,而model.eval()就会做这个事情。
原文链接: https://arxiv.org/abs/1207.0580
通常意义的dropout解释为:在训练过程的前向传播中,让每个神经元以一定概率p处于不激活的状态。以达到减少过拟合的效果。
然而,在pytorch中,dropout有另一个用法。如果把dropout加在输入张量上:
x = torch.randn(20, 16)
dropout = nn.Dropout(p=0.2)
x_drop = dropout(x)
1.Dropout是为了防止过拟合而设置的
2.Dropout顾名思义有丢掉的意思
3.nn.Dropout(p = 0.3) # 表示每个神经元有0.3的可能性不被激活
4.Dropout只能用在训练部分而不能用在测试部分
5.Dropout一般用在全连接神经网络映射层之后,如代码的nn.Linear(20, 30)之后
import torch
import torch.nn as nn
a = torch.randn(4, 4)
print(a)
"""
tensor([[ 1.2615, -0.6423, -0.4142, 1.2982],
[ 0.2615, 1.3260, -1.1333, -1.6835],
[ 0.0370, -1.0904, 0.5964, -0.1530],
[ 1.1799, -0.3718, 1.7287, -1.5651]])
"""
dropout = nn.Dropout()
b = dropout(a)
print(b)
"""
tensor([[ 2.5230, -0.0000, -0.0000, 2.5964],
[ 0.0000, 0.0000, -0.0000, -0.0000],
[ 0.0000, -0.0000, 1.1928, -0.3060],
[ 0.0000, -0.7436, 0.0000, -3.1303]])
"""
由以上代码可知Dropout还可以将部分tensor中的值置为0
三、BatchNorm1d
torch.nn.BatchNorm1d(num_features, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
num_features – 特征维度
eps – 为数值稳定性而加到分母上的值。
momentum – 移动平均的动量值。
affine – 一个布尔值,当设置为真时,此模块具有可学习的仿射参数。
四、nn.CrossEntropyLoss()与NLLLoss
在图片单标签分类时,输入m张图片,输出一个mN的Tensor,其中N是分类个数。比如输入3张图片,分三类,最后的输出是一个33的Tensor,举个例子:
nn.CrossEntropyLoss()函数计算交叉熵损失
用法:
# output是网络的输出,size=[batch_size, class]
#如网络的batch size为128,数据分为10类,则size=[128, 10]
# target是数据的真实标签,是标量,size=[batch_size]
#如网络的batch size为128,则size=[128]
crossentropyloss=nn.CrossEntropyLoss()
crossentropyloss_output=crossentropyloss(output,target)
注意,使用nn.CrossEntropyLoss()时,不需要现将输出经过softmax层,否则计算的损失会有误,即直接将网络输出用来计算损失即可
nn.CrossEntropyLoss()的计算公式为:
其中x是网络的输出向量,class是真实标签
举个例子,一个三分类网络对某个输入样本的输出为[-0.7715, -0.6205,-0.2562],该样本的真实标签为0,则用nn.CrossEntropyLoss()计算的损失为:
NLLLoss
在图片单标签分类时,输入m张图片,输出一个mN的Tensor,其中N是分类个数。比如输入3张图片,分三类,最后的输出是一个33的Tensor,举个例子:
input = torch.randn(3,3)
print('input', input)

第123行分别是第123张图片的结果,假设第123列分别是猫、狗和猪的分类得分。
然后对每一行使用Softmax,这样可以得到每张图片的概率分布。
sm = nn.Softmax(dim=1)
output = sm(input)
print('output', output)

这里dim的意思是计算Softmax的维度,这里设置dim=1,可以看到每一行的加和为1。比如第一行之和=1。
sm = nn.Softmax(dim=0)
output2= sm(input)
print('output2', output2)

如果设置dim=0,就是一列的和为1。比如第一列之和=1。
我们这里一张图片是一行,所以dim应该设置为1。
然后对Softmax的结果取自然对数:
print(torch.log(sm(input)))

Softmax后的数值都在0~1之间,所以ln之后值域是负无穷到0。
NLLLoss的结果就是把上面的输出与Label对应的那个值拿出来,再去掉负号,再求均值。
假设我们现在Target是[0,2,1](第一张图片是猫,第二张是猪,第三张是狗)。第一行取第0个元素,第二行取第2个,第三行取第1个,去掉负号,结果是:[0.4155,1.0945,1.5285]。再求个均值,结果是:
loss = nn.NLLLoss()
target = torch.tensor([0,2,1])
LOS = loss(torch.log(sm(input)),target)
print('LOS', LOS)

CrossEntropyLoss就是把以上Softmax–Log–NLLLoss合并成一步,我们用刚刚随机出来的input直接验证一下结果是不是1.0128:
loss = nn.CrossEntropyLoss()
target = torch.tensor([0,2,1])
LOS2 = loss(input,target)
print('LOS2', LOS2)

边栏推荐
- Radware warns about the next round of DDoS Attacks
- DDoS threat situation gets worse
- DHU programming exercise
- Dynamic SQL
- How does payment splitting help B2B bulk commodity transactions?
- 如何制作CSR(Certificate Signing Request)文件?
- 一种跳板机的实现思路
- How to display all keys through redis cli- How to show ALL keys through redis-cli?
- 【银河麒麟V10】【桌面】火狐浏览器设置主页不生效
- How do PMP candidates respond to the new exam outline? Look!
猜你喜欢

Matlab 2012a drawing line segment with arrow
![[naturallanguageprocessing] [multimodality] ofa: unified architecture, tasks and modes through a simple sequence to sequence learning framework](/img/c9/7be54c428212d7226cbbbb4800fcdb.png)
[naturallanguageprocessing] [multimodality] ofa: unified architecture, tasks and modes through a simple sequence to sequence learning framework

论文回顾:Playful Palette: An Interactive Parametric Color Mixer for Artists
![[论]【DSTG】Dynamic SpatiotemporalGraph Convolutional Neural Networks for Traffic Data Imputation](/img/c3/f9d6399c931a006ca295bb1e3ac427.png)
[论]【DSTG】Dynamic SpatiotemporalGraph Convolutional Neural Networks for Traffic Data Imputation

主流CA吊销俄罗斯数字证书启示:升级国密算法SSL证书,助力我国网络安全自主可控
![[MySQL 05] SUSE 12 SP5 modifies the MySQL password for the first time after installing MySQL](/img/37/d24c9e5fad606d2623900ad018b6af.png)
[MySQL 05] SUSE 12 SP5 modifies the MySQL password for the first time after installing MySQL

Illustration Google V8 19: asynchronous programming (II): how does V8 implement async/await?

Heap sort

What problems can cloud storage architecture solve for Devops?

VScode如何Debug(调试)进入标准库文件/第三方包源码
随机推荐
DDoS threat situation gets worse
As VoIP became the target, DDoS attacks surged by 35% in the third quarter
widget使用setImageViewBitmap方法设置bug分析
Large scale DDoS attacks and simulated DDoS tests against VoIP providers
The birth of the cheapswap protocol
dhu编程练习
Realization of a springboard machine
2022年7月深圳地区CPDA数据分析师认证
论文回顾:Playful Palette: An Interactive Parametric Color Mixer for Artists
Illustration Google V8 19: asynchronous programming (II): how does V8 implement async/await?
Select sort
DHU programming exercise
有流量,但没有销售?增加网站销量的 6 个步骤
IBM websphere通道联通搭建和测试
隐藏在科技教育中的steam元素
2.< tag-动态规划和0-1背包问题>lt.416. 分割等和子集 + lt.1049. 最后一块石头的重量 II
dhu编程练习
【自然语言处理】【多模态】OFA:通过简单的sequence-to-sequence学习框架统一架构、任务和模态
快速排序
CA数字证书包含哪些文件?如何查看SSL证书信息?