
Gather function in PyTorch

2022-06-12 20:58:00 Human high quality Algorithm Engineer

Let's first look at the explanation in the official documentation:

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0

out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1

out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
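The three formulas above can be checked directly. The following is a minimal sketch, not PyTorch's actual implementation: a hypothetical helper `gather_3d_loops` evaluates the documented formulas with plain Python loops, and its result is compared with `torch.gather` for each value of `dim`. The tensor shape and random seed are arbitrary choices for illustration.

```python
import torch


def gather_3d_loops(src, dim, index):
    # Hypothetical reference helper: evaluates the documented 3-D formulas
    # element by element. Very slow; for illustration only.
    out = torch.empty(index.shape, dtype=src.dtype)
    I, J, K = index.shape
    for i in range(I):
        for j in range(J):
            for k in range(K):
                idx = index[i, j, k].item()
                if dim == 0:
                    out[i, j, k] = src[idx, j, k]   # index replaces i
                elif dim == 1:
                    out[i, j, k] = src[i, idx, k]   # index replaces j
                else:                               # dim == 2
                    out[i, j, k] = src[i, j, idx]   # index replaces k
    return out


torch.manual_seed(0)
x = torch.randn(2, 3, 4)

for dim in range(3):
    # Index values must lie in [0, x.size(dim)); here index has the same shape as x.
    index = torch.randint(0, x.size(dim), x.shape)
    assert torch.equal(gather_3d_loops(x, dim, index), torch.gather(x, dim, index))
    print(f"dim={dim}: loop formula matches torch.gather")
```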

If input is an $n$-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, x_i, x_{i+1}, \ldots, x_{n-1})$ and dim = $i$, then index must be an $n$-dimensional tensor with size $(x_0, x_1, \ldots, x_{i-1}, y, x_{i+1}, \ldots, x_{n-1})$ where $y \geq 1$ and out will have the same size as index.
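To see the shape rule in action, here is a small sketch with made-up tensors: with dim = 1, the index may have a different size $y$ along dimension 1 than the input (here $y = 4$ against $x_1 = 3$), and the output simply takes the shape of index. Every index value must still be a valid position along dim, i.e. in [0, 3). Recent versions of the PyTorch documentation phrase the constraint a little more loosely (same number of dimensions, with index.size(d) <= input.size(d) for every d != dim); this example satisfies both phrasings.

```python
import torch

src = torch.arange(6).reshape(2, 3)   # [[0, 1, 2], [3, 4, 5]], shape (x0, x1) = (2, 3)

# dim = 1: index keeps x0 = 2 rows but has y = 4 columns; values are in [0, 3)
index = torch.tensor([[0, 0, 2, 1],
                      [2, 1, 1, 0]])

out = torch.gather(src, 1, index)
print(out.shape)   # torch.Size([2, 4]), the same shape as index
print(out)         # rows are [0, 0, 2, 1] and [5, 4, 4, 3]
```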

Let's take an example:

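import torch

b = torch.tensor([[1., 2., 3.], [4., 5., 6.]])   # float32 input
index_1 = torch.tensor([[0, 1], [2, 0]])         # int64 indices, as gather requires

# dim=1: each row of index_1 selects columns from the same row of b
print(torch.gather(b, dim=1, index=index_1))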

Output:

tensor([[1., 2.],
        [6., 4.]])

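Next, compute each element of the output by hand using the formula from the documentation. Since b is 2-D and dim = 1, the rule reduces to out[i][j] = input[i][index[i][j]]:

out[0][0] = input[0][index[0][0]] = input[0][0] = 1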

out[0][1] = input[0][index[0][1]] = input[0][1] = 2

out[1][0] = input[1][index[1][0]] = input[1][2] = 6

out[1][1] = input[1][index[1][1]] = input[1][0] = 4
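The hand calculation above can be mirrored in code. This is a sketch: the helper `gather_2d_loops` is hypothetical and only covers 2-D tensors, where the documented rule reduces to out[i][j] = input[index[i][j]][j] for dim = 0 and out[i][j] = input[i][index[i][j]] for dim = 1. The second call, with dim = 0 and a made-up index `index_0`, is included only for contrast.

```python
import torch


def gather_2d_loops(src, dim, index):
    # Hypothetical helper: the 2-D case of the documented gather formula.
    out = torch.empty(index.shape, dtype=src.dtype)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            idx = index[i, j].item()
            # dim == 1 replaces the column position, dim == 0 replaces the row position
            out[i, j] = src[i, idx] if dim == 1 else src[idx, j]
    return out


b = torch.tensor([[1., 2., 3.], [4., 5., 6.]])

# The example from the text: dim=1 picks columns within each row
index_1 = torch.tensor([[0, 1], [2, 0]])
print(gather_2d_loops(b, 1, index_1))   # same values as torch.gather: [[1., 2.], [6., 4.]]

# For contrast: dim=0 picks rows within each column
index_0 = torch.tensor([[1, 0, 1]])
print(torch.gather(b, 0, index_0))      # tensor([[4., 2., 6.]])
```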

Reprinted from: https://blog.csdn.net/weixin_39757893/article/details/112926469


Copyright notice
This article was written by [Human high quality Algorithm Engineer]. Please include a link to the original when reprinting. Thanks.
https://yzsam.com/2022/02/202202281434270398.html