当前位置:网站首页>pytorch中几个难理解的方法整理--gather&squeeze&unsqueeze
pytorch中几个难理解的方法整理--gather&squeeze&unsqueeze
2022-07-27 05:01:00 【CharlesLC的博客】
gather
pytorch中gather源码形式:torch.gather(input, dim, index, *, sparse_grad = False, out = None)
然后在pytorch官方文档中,写了这样的一个例子,这个例子是三维的
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
刚开始比较难理解,不知道什么意思,于是试了几个例子
一维:
>>> array1 = torch.tensor([1,2,3])
>>> torch.gather(array1, 0, torch.tensor([0,1]))
tensor([1, 2])
上述例子中,array1的矩阵形式为array1 = [1,2,3], 按维度0取值(对于一维的情况,顶多也为0), 将[array1[0],array1[1]]作为输出结果,也就是[1,2]
二维
>>> array2 = torch.tensor([[1,2,3],[4,5,6]])
>>> torch.gather(array2, 0, torch.tensor([[0, 1]]))
tensor([[1, 5]])
在上述二维的例子中,array2的形式为array2 = [[1,2,3],[4,5,6]], 按维度0取值, 将[array2[0][0], array2[1][1]]为输出结果,也就是[1,5]
看了上面两个例子,我们根据torch.gather(input, dim, index, *, sparse_grad = False, out = None)
看下公式:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
明确一点的是,输出的size是和index的size是一样的。
对于一维的,假设index的大小为n, 那么输出结果为 [input[index[0]], input[index[1]], … , input[index[n]]], 也就是我们上个例子中的[1,2]
对于二维的,如果dim为0,假设index的大小为m*n, 那么输出结果为
[[input[index[0][0]][0], input[index[0][1]][1], … , input[index[0][n]][n],
[input[index[1][0]][0], input[index[1][1]][1], … , input[index[1][n]][n],
…
[input[index[m][0]][0], input[index[m][1]][1], …, input[index[m][n]][n]
所以上面的例子我们输出为[1,5]
如果dim为1呢, 同样假设index的大小为m*n,那么输出结果为:
[[input[0][index[0][0]], input[1][index[0][1]], … , input[n][index[0][n]],
[input[0][index[1][0]], input[1][index[1][1]], … , input[n][index[1][n]],
…
[input[0][index[m][0]], input[1][index[m][1]], …, input[n][index[m][n]],
看到这里那是不是就对官网给出的公式有点理解了呢?
再来看几个3维的例子
array = torch.tensor(np.arange(24)).view(2,3,4)
array
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)
torch.gather(array, 0, torch.tensor([[[0,1],[1,1]]]))
输出结果为
[[[array[0][0][0],array[1][0][1],
[array[1][1][0],array[1][1][1]
]]
tensor([[[ 0, 13],
[16, 17]]], dtype=torch.int32)
看到这里是不是有点理解了,不理解的平时多试试就比较清楚了。
squeeze
torch.squeeze中函数形式torch.squeeze(input, dim = None, *, out = None) -> Tensor,默认dim参数为None,
官网也描述了它的作用:Returns a tensor with all the dimensions of input of size 1 removed.
就是移除所有size为1的维度,比如说输入一个array,它的shape为(1,2,3,1,2),那么他的output的size为(2,3,2)
具体看一下例子:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array).size()
torch.Size([2, 3, 2])
如果某个维度的size不为1,那就不移除。
另外还有一种写法可以移除特定size为1的维度,写法torch.squeeze(array, dim)
例如:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array,0).size()
torch.Size([2, 3, 1, 2])
这里第四维度的1就没有移除掉。
下面我们再来看下unsqueeze方法
unsqueeze
torch官网中描述的方法 torch.unsqueeze(input, dim)->Tensor
作用是返回一个在特定维度插入size为1的tensor
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
这个dim可以为多少呢?官方也是做出了解释A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.
就是说,dim的输入只能在[-input.dim()-1, input.dim() + 1]范围内。在上面的例子中,维度限制在[-2, 1]之间。
如果是负数怎么处理呢? dim = dim + input.dim() + 1, 也就是说,如果输入-2, 那么应该输出dim = 0,
其实从这个公式,和list中里面的选取元素差不多,
例如list = [1,2,3,4]; list[0]= 1, list[-1] = 4,相当于 dim为-1 就是在最高维插入size为1, 而当dim为**-input.dim() - 1**相当于在维度0处插入size为1。
例子
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, -2).size()
torch.Size([1, 4])
>>> torch.unsqueeze(x, -1).size()
torch.Size([4, 1])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
看到这里是不是就理解了呢?
(其他方法后续跟新)
边栏推荐
猜你喜欢

2021 OWASP top 5: security configuration error

JVM Part 1: memory and garbage collection part 9 - runtime data area - object instantiation, memory layout and access location

JVM Part 1: memory and garbage collection part 12 -- stringtable

Acticiti中startProcessInstanceByKey方法在variable表中的如何存储

JVM上篇:内存与垃圾回收篇七--运行时数据区-堆

来自“飞人”乔丹的启示!奥尼尔开启的另一个“赛场”

JVM Part 1: memory and garbage collection part 6 -- runtime data area local method & local method stack

File processing (IO)

Gradio quickly builds ml/dl Web Services

笔记系列之docker安装Postgresql 14
随机推荐
B1025 反转链表*******
The interface can automatically generate E and other asynchronous access or restart,
《Robust and Precise Vehicle Localization based on Multi-sensor Fusionin Diverse City Scenes》翻译
探寻通用奥特能平台安全、智能、性能的奥秘!
JVM Part 1: memory and garbage collection part 10 - runtime data area - direct memory
35.滚动 scroll
文件处理(IO)
The provision of operation and maintenance manager is significantly affected, and, for example, it is like an eep command
Basic operation of vim
B1028 人口普查
Database design - relational data theory (ultra detailed)
A math problem cost the chip giant $500million
Shell course summary
树莓派rtmp推流本地摄像头图像
2021 OWASP top 4: unsafe design
[optical flow] - data format analysis, flowwarp visualization
Detailed description of binary search tree
How idea creates a groovy project (explain in detail with pictures and texts)
JVM上篇:内存与垃圾回收篇三--运行时数据区-概述及线程
B1026 program running time