当前位置:网站首页>PyTorch磨刀篇|argmax和argmin函数
PyTorch磨刀篇|argmax和argmin函数
2022-07-01 21:44:00 【51CTO】
一、语法格式
格式一(只针对argmax函数):
torch.argmax(input) → LongTensor
功能:
Returns the indices of the maximum value of all elements in the input tensor。
即:返回输入张量中所有元素中最大值对应的索引(按行搜索);如果有多个相同的值,则返回第一次遇到的那个值对应的索引。
举例:
In [28]: r=torch.tensor([[1,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15]])
In [29]: torch.argmax(r)
Out[29]: tensor(14)
格式二:
[1]torch.argmax(input, dim=None, keepdim=False)
功能:
Returns the indices of the maximum values of a tensor across a dimension.
- input( Tensor) – the input tensor.即:输出张量。
- dim( int) – the dimension to reduce. If
None
, the argmax of the flattened input is returned.即:要减少的维数。
- keepdim( bool) – whether the output tensor has
dim
retained or not. Ignored if dim=None
.即:
举例:
In [30]: a = torch.randn(4, 4)
In [31]: a
Out[31]:
tensor([[ 1.4360, 0.6342, -0.5233, 0.4902],
[ 1.1998, -0.8644, 0.5244, 0.2690],
[ 0.0998, -1.5043, 0.1619, -1.4634],
[ 0.0992, -1.0843, -1.3829, 0.5790]])
In [32]: torch.argmax(a)
Out[32]: tensor(0)
In [33]: torch.argmax(a,dim=0)
Out[33]: tensor([0, 0, 1, 3])
In [34]: torch.argmax(a,dim=1)
Out[34]: tensor([0, 0, 2, 3])
- 1.
- 2.
- 3.
- 4.
- 5.
- 6.
- 7.
- 8.
- 9.
- 10.
- 11.
- 12.
- 13.
- 14.
- 15.
- 16.
- 17.
- 对于tensor(0)输出,意义如下:
第0个: 1.4360 | 第1个: 0.6342 | 第2个: -0.5233 | 第3个: 0.4902 | 第4个: 1.1998 | 第5个: -0.8644 | 第6个: 0.5244 | 第7个: 0.2690 | 第8个: 0.0998 | 第9个: -1.5043 |
第10个: 0.1619 | 第11个: -1.4634 | 第12个: 0.0992 | 第13个: -1.0843 | 第14个: -1.3829 | 第15个: 0.5790 |
- 对于tensor([0, 0, 1, 3])输出,意义如下:
这时,每一列视为下标从0到3的一个数组。易见,从左到右每一列(数组)中最大值分别为:1.4360、0.6342、0.5244、0.5790,它们对应的一维数组中的下标分别为0、0、1、3,于是得到张量tensor([0, 0, 1, 3])。
- 对于tensor([0, 0, 2, 3])输出:
意义就容易理解了。沿水平方向从左向右从上到下看,每一行对应一个数组,下标向左向右依次为0、1、2、3。于是,这4个数组中最大值分别为1.4360、1.1998、0.1619、1.3829,它们对应的一维数组中的下标分别为0、0、2、3,于是得到张量tensor([0, 0, 2, 3])。
功能:
[2]torch.argmin(input, dim=None, keepdim=False) → LongTensor
argmin功能:Returns the indices of the minimum value(s) of the flattened tensor or along a dimension。
理解类似上面argmax函数的第二种格式,相应于dim=0和dim=1,依次返回由最小值对应下标组成的列方向数组与行方向数组组成的张量。
边栏推荐
- 【STM32】STM32CubeMX教程二–基本使用(新建工程点亮LED灯)
- pytest合集(2)— pytest運行方式
- 100年仅6款产品获批,疫苗竞争背后的“佐剂”江湖
- Fundamentals - IO intensive computing and CPU intensive computing
- Show member variables and methods in classes in idea
- Design and practice of new generation cloud native database
- 【juc学习之路第9天】屏障衍生工具
- String类型转换BigDecimal、Date类型
- Go - exe corresponding to related dependency
- Icml2022 | interventional contrastive learning based on meta semantic regularization
猜你喜欢
Talking from mlperf: how to lead the next wave of AI accelerator
[NOIP2013]积木大赛 [NOIP2018]道路铺设 贪心/差分
AirServer2022最新版功能介绍及下载
I received a letter from CTO inviting me to interview machine learning engineer
微软、哥伦比亚大学|GODEL:目标导向对话的大规模预训练
Introduction and download of the latest version of airserver2022
[noip2013] building block competition [noip2018] road laying greed / difference
Business visualization - make your flowchart'run'up
Copy ‘XXXX‘ to effectively final temp variable
Go - exe corresponding to related dependency
随机推荐
《QTreeView+QAbstractItemModel自定义模型》:系列教程之三[通俗易懂]
物联网rfid等
[NOIP2013]积木大赛 [NOIP2018]道路铺设 贪心/差分
Sonic云真机学习总结6 - 1.4.1服务端、agent端部署
焱融看 | 混合云时代下,如何制定多云策略
ICML2022 | 基于元语义正则化的介入性对比学习
Go — 相关依赖对应的exe
比较版本号[双指针截取自己想要的字串]
pytest合集(2)— pytest运行方式
中通笔试题:翻转字符串,例如abcd打印出dcba
游览器打开摄像头案例
plantuml介绍与使用
统计字符中每个字符出现的个数
Communication between browser tab pages
Can you get a raise? Analysis on gold content of PMP certificate
【生态伙伴】鲲鹏系统工程师培训
Test cancellation 1
100年仅6款产品获批,疫苗竞争背后的“佐剂”江湖
I received a letter from CTO inviting me to interview machine learning engineer
基础—io密集型计算和cpu密集型计算