Sorting out a few tricky methods in PyTorch: gather & squeeze & unsqueeze
2022-07-27 05:28:00 【Charleslc's blog】
gather
The signature of gather in PyTorch is torch.gather(input, dim, index, *, sparse_grad=False, out=None).
The official PyTorch documentation illustrates it with the following rule for the three-dimensional case:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
It is hard to understand at first glance, so I tried a few examples.
A one-dimensional example:
>>> array1 = torch.tensor([1,2,3])
>>> torch.gather(array1, 0, torch.tensor([0,1]))
tensor([1, 2])
In this example, array1 = [1,2,3]. Gathering along dimension 0 (the only dimension a 1-D tensor has) takes [array1[0], array1[1]] as the output, i.e. [1, 2].
A two-dimensional example:
>>> array2 = torch.tensor([[1,2,3],[4,5,6]])
>>> torch.gather(array2, 0, torch.tensor([[0, 1]]))
tensor([[1, 5]])
In this two-dimensional example, array2 = [[1,2,3],[4,5,6]]. Gathering along dimension 0 selects [array2[0][0], array2[1][1]], giving the output [[1, 5]].
With these two examples in mind, go back to torch.gather(input, dim, index, *, sparse_grad=False, out=None) and look at the formula again:
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
One thing is immediately clear: the output has the same size as index.
In the one-dimensional case, if index has size n, the output is [input[index[0]], input[index[1]], …, input[index[n-1]]], which gives [1, 2] in our example above.
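As an extra check (my own hypothetical example, not from the original post), index may even be longer than input and may repeat positions; the output always has exactly the shape of index:
>>> torch.gather(array1, 0, torch.tensor([2, 2, 0, 1]))
tensor([3, 3, 1, 2])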
In the two-dimensional case with dim = 0, if index has size m×n, the output is
[[input[index[0][0]][0], input[index[0][1]][1], …, input[index[0][n-1]][n-1]],
 [input[index[1][0]][0], input[index[1][1]][1], …, input[index[1][n-1]][n-1]],
 …
 [input[index[m-1][0]][0], input[index[m-1][1]][1], …, input[index[m-1][n-1]][n-1]]]
which is why the example above outputs [[1, 5]].
If dim is 1, again with index of size m×n, the row index of input stays fixed within each output row and only the column comes from index:
[[input[0][index[0][0]], input[0][index[0][1]], …, input[0][index[0][n-1]]],
 [input[1][index[1][0]], input[1][index[1][1]], …, input[1][index[1][n-1]]],
 …
 [input[m-1][index[m-1][0]], input[m-1][index[m-1][1]], …, input[m-1][index[m-1][n-1]]]]
Does the official formula make a little more sense now?
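The dim = 1 case never got a concrete example above, so here is a small sketch of my own (the tensors and values are made up): picking one element out of each row, the classic "select each sample's target-class score" pattern:
>>> scores = torch.tensor([[0.1, 0.6, 0.3],
...                        [0.8, 0.1, 0.1]])
>>> targets = torch.tensor([[1], [0]])   # one column index per row
>>> torch.gather(scores, 1, targets)
tensor([[0.6000],
        [0.8000]])
Here out[i][0] = scores[i][targets[i][0]], exactly the dim = 1 formula.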
Now let's look at a three-dimensional example:
array = torch.tensor(np.arange(24)).view(2,3,4)  # assumes numpy is imported as np
array
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]], dtype=torch.int32)
torch.gather(array, 0, torch.tensor([[[0,1],[1,1]]]))
The output selects
[[[array[0][0][0], array[1][0][1]],
  [array[1][1][0], array[1][1][1]]]]
that is,
tensor([[[ 0, 13],
[16, 17]]], dtype=torch.int32)
By this point it should be starting to click; if not, working through a few more examples of your own makes it clearer. To double-check the formula itself, see the sketch below.
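If you want to convince yourself that the official formula is really what gather computes, here is a rough sketch of my own that re-implements the dim == 0 case with plain loops and compares it against torch.gather:
import numpy as np
import torch

array = torch.tensor(np.arange(24)).view(2, 3, 4)
index = torch.tensor([[[0, 1], [1, 1]]])

# re-implement out[i][j][k] = array[index[i][j][k]][j][k]  (dim == 0)
out = torch.empty_like(index)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            out[i, j, k] = array[index[i, j, k], j, k]

print(torch.equal(out, torch.gather(array, 0, index)))  # prints: True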
squeeze
The signature of torch.squeeze is torch.squeeze(input, dim=None, *, out=None) -> Tensor, with dim defaulting to None.
The official docs describe its effect as: "Returns a tensor with all the dimensions of input of size 1 removed."
That is, it removes every dimension of size 1. For example, for an input of shape (1,2,3,1,2), the output has size (2,3,2).
An example:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array).size()
torch.Size([2, 3, 2])
Dimensions whose size is not 1 are left alone.
There is also a way to remove one specific size-1 dimension: torch.squeeze(array, dim).
For example:
>>> array = torch.zeros(1,2,3,1,2)
>>> torch.squeeze(array,0).size()
torch.Size([2, 3, 1, 2])
Here only dimension 0 was removed; the size-1 fourth dimension (dim 3) stays in place.
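To complete the picture, a couple more calls on the same tensor (my own additions): passing dim = 3 removes the other size-1 dimension instead, and passing a dimension whose size is not 1 is simply a no-op:
>>> torch.squeeze(array, 3).size()
torch.Size([1, 2, 3, 2])
>>> torch.squeeze(array, 1).size()   # dim 1 has size 2, so nothing happens
torch.Size([1, 2, 3, 1, 2])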
Finally, let's look at the unsqueeze method.
unsqueeze
The official signature is torch.unsqueeze(input, dim) -> Tensor.
It returns a new tensor with a dimension of size 1 inserted at the given position.
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
What values can dim take? The docs explain: "A dim value within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1."
In other words, dim must lie in the half-open range [-input.dim() - 1, input.dim() + 1). For the 1-D tensor above, that means dim ∈ {-2, -1, 0, 1}.
How are negative values handled? Via dim = dim + input.dim() + 1: passing -2 is equivalent to passing dim = 0.
This formula works just like negative indexing into a Python list: with lst = [1,2,3,4], lst[0] = 1 and lst[-1] = 4. Likewise, dim = -1 inserts the size-1 dimension at the last position, while dim = -input.dim() - 1 inserts it at dimension 0.
Examples:
>>> x = torch.tensor([1, 2, 3, 4])
>>> torch.unsqueeze(x, -2).size()
torch.Size([1, 4])
>>> torch.unsqueeze(x, -1).size()
torch.Size([4, 1])
>>> torch.unsqueeze(x, 0).size()
torch.Size([1, 4])
Hopefully it all makes sense now.
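One place where unsqueeze constantly shows up in practice (a sketch of my own, with made-up shapes): adding a batch dimension of size 1 so that a single sample can be fed to code that expects batched input:
>>> img = torch.zeros(3, 224, 224)   # a single C x H x W image
>>> img.unsqueeze(0).size()          # insert a batch dimension at position 0
torch.Size([1, 3, 224, 224])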
(More methods to follow in later posts.)