当前位置:网站首页>pytorch中几个难理解的方法整理--gather&squeeze&unsqueeze
pytorch中几个难理解的方法整理--gather&squeeze&unsqueeze
2022-07-27 05:01:00 【CharlesLC的博客】
gather
pytorch中gather源码形式:torch.gather(input, dim, index, *, sparse_grad = False, out = None)
然后在pytorch官方文档中,写了这样的一个例子,这个例子是三维的
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
刚开始比较难理解,不知道什么意思,于是试了几个例子
一维:
>>> array1 = torch.tensor([1,2,3])
>>> torch.gather(array1, 0, torch.tensor([0,1]))
tensor([1, 2])
上述例子中,array1的矩阵形式为array1 = [1,2,3], 按维度0取值(对于一维的情况,顶多也为0), 将[array1[0],array1[1]]作为输出结果,也就是[1,2]
二维
>>> array2 = torch.tensor([[1,2,3],[4,5,6]])
>>> torch.gather(array2, 0, torch.tensor([[0, 1]]))
tensor([[1, 5]])
在上述二维的例子中,array2的形式为array2 = [[1,2,3],[4,5,6]], 按维度0取值, 将[array2[0][0], array2[1][1]]为输出结果,也就是[1,5]
看了上面两个例子,我们根据torch.gather(input, dim, index, *, sparse_grad = False, out = None)
看下公式:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
明确一点的是,输出的size是和index的size是一样的。
对于一维的,假设index的大小为n, 那么输出结果为 [input[index[0]], input[index[1]], … , input[index[n]]], 也就是我们上个例子中的[1,2]
对于二维的,如果dim为0,假设index的大小为m*n, 那么输出结果为
[[input[index[0][0]][0], input[index[0][1]][1], … , input[index[0][n]][n],
[input[index[1][0]][0], input[index[1][1]][1], … , input[index[1][n]][n],
…
[input[index[m][0]][0], input[index[m][1]][1], …, input[index[m][n]][n]
所以上面的例子我们输出为[1,5]
如果dim为1呢, 同样假设index的大小为m*n,那么输出结果为:
[[input[0][index[0][0]], input[1][index[0][1]], … , input[n][index[0][n]],
[input[0][index[1][0]], input[1][index[1][1]], … , input[n][index[1][n]],
…
[input[0][index[m][0]], input[1][index[m][1]], …, input[n][index[m][n]],
看到这里那是不是就对官网给出的公式有点理解了呢?
再来看几个3维的例子
array = torch.tensor(np.arange(24)).view(2,3,4)
array
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)
torch.gather(array, 0, torch.tensor([[[0,1],[1,1]]]))
输出结果为
[[[array[0][0][0],array[1][0][1],
[array[1][1][0],array[1][1][1]
]]
tensor([[[ 0, 13],
[16, 17]]], dtype=torch.int32)
看到这里是不是有点理解了,不理解的平时多试试就比较清楚了。
squeeze
torch.squeeze中函数形式torch.squeeze(input, dim = None, *, out = None) -> Tensor,默认dim参数为None,
官网也描述了它的作用:Returns a tensor with all the dimensions of input of size 1 removed.
就是移除所有size为1的维度,比如说输入一个array,它的shape为(1,2,3,1,2),那么他的output的size为(2,3,2)
具体看一下例子:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array).size()
torch.Size([2, 3, 2])
如果某个维度的size不为1,那就不移除。
另外还有一种写法可以移除特定size为1的维度,写法torch.squeeze(array, dim)
例如:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array,0).size()
torch.Size([2, 3, 1, 2])
这里第四维度的1就没有移除掉。
下面我们再来看下unsqueeze方法
unsqueeze
torch官网中描述的方法 torch.unsqueeze(input, dim)->Tensor
作用是返回一个在特定维度插入size为1的tensor
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
这个dim可以为多少呢?官方也是做出了解释A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
就是说,dim的输入只能在[-input.dim()-1, input.dim() + 1]范围内。在上面的例子中,维度限制在[-2, 1]之间。
如果是负数怎么处理呢? dim = dim + input.dim() + 1, 也就是说,如果输入-2, 那么应该输出dim = 0,
其实从这个公式,和list中里面的选取元素差不多,
例如list = [1,2,3,4]; list[0]= 1, list[-1] = 4,相当于 dim为-1 就是在最高维插入size为1, 而当dim为**-input.dim() - 1**相当于在维度0处插入size为1。
例子
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, -2).size()
torch.Size([1, 4])
>>> torch.unsqueeze(x, -1).size()
torch.Size([4, 1])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
看到这里是不是就理解了呢?
(其他方法后续跟新)
边栏推荐
- Scientific Computing Library - numpy
- 数据库设计——关系数据理论(超详细)
- 稀疏数组→五子棋的存盘续盘等操作
- JVM上篇:内存与垃圾回收篇十一--执行引擎
- OFDM 16 lecture 2-ofdm and the DFT
- [Niuke discussion area] Chapter 7: building safe and efficient enterprise services
- MQ message queue is used to design the high concurrency of the order placing process, the generation scenarios and solutions of message squeeze, message loss and message repetition
- 实用小工具: Kotlin 代码片段
- Quoted popular explanation
- B1027 打印沙漏
猜你喜欢

JVM上篇:内存与垃圾回收篇二--类加载子系统

树莓派rtmp推流本地摄像头图像

JDBC API 详解

JVM Part 1: memory and garbage collection part 5 -- runtime data area virtual machine stack

JVM上篇:内存与垃圾回收篇十一--执行引擎

Use ngrok for intranet penetration

JVM上篇:内存与垃圾回收篇六--运行时数据区-本地方法&本地方法栈

《Robust and Precise Vehicle Localization based on Multi-sensor Fusionin Diverse City Scenes》翻译

Raspberry pie RTMP streaming local camera image

Shell course summary
随机推荐
Bean的生命周期&&依赖注入*依赖自动装配
牛客剑指offer--JZ12 矩阵中的路径
B1023 组个最小数
JVM Part 1: memory and garbage collection part 8 - runtime data area - Method area
2021 OWASP top 5: security configuration error
Solution and principle analysis of feign call missing request header
JVM上篇:内存与垃圾回收篇六--运行时数据区-本地方法&本地方法栈
2022 Zhengzhou light industry Freshmen's competition topic - I won't say if I'm killed
Tcp server是如何一个端口处理多个客户端连接的(一对一还是一对多)
mq设置过期时间、优先级、死信队列、延迟队列
Set static IP for raspberry pie
mq常见问题
简化JDBC的MyBits框架
Scientific Computing Library -- Matplotlib
Install pyGame
Static and final keyword learning demo exercise
DBUtils
知识点总结(一)
枚举类实现单例模式
JVM上篇:内存与垃圾回收篇十一--执行引擎