Alan Lee

理解 PyTorch 中的 gather 函数

2021/08/21 Share

好久没更新博客了,最近一直在忙,既有生活上的也有工作上的。道阻且长啊。

今天来水一文,说一说最近工作上遇到的一个函数:torch.gather()

文字理解

我遇到的代码是 NLP 相关的,代码中用 torch.gather() 来将一个 tensor 的 shape 从 (batch_size, seq_length, hidden_size) 转为 (batch_size, labels_length, hidden_size) ,其中 seq_length >= labels_length

torch.gather() 的官方解释是

Gathers values along an axis specified by dim.

就是在指定维度上 gather value。那么怎么 gather、gather 哪些 value 呢?这就要看其参数了。

torch.gather() 的必填也是最常用的参数有三个,下面引用官方解释:

  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather

所以一句话概括 gather 操作就是:根据 index ,在 inputdim 维度上收集 value

具体来说,input 就是源 tensor,等会我们要在这个 tensor 上执行 gather 操作。如果 input 是一个一维数组,即 flat 列表,那么我们就可以直接根据 indexinput 上取了,就像正常的列表/数组索引一样。但是由于 input 可能含有多个维度,是 N 维数组,所以我们需要知道在哪个维度上进行 gather,这就是 dim 的作用。

对于 dim 参数,一种更为具体的理解方式是替换法。假设 inputindex 均为三维数组,那么输出 tensor 每个位置的索引是列表 [i, j, k] ,正常来说我们直接取 input[i, j, k] 作为 输出 tensor 对应位置的值即可,但是由于 dim 的存在以及 input.shape 可能不等于 index.shape ,所以直接取值可能就会报 IndexError 。所以我们是将索引列表的相应位置替换为 dim ,再去 input 取值。如果 dim=0 ,我们就替换索引列表第 0 个值,即 [dim, j, k] ,依此类推。Pytorch 的官方文档也是这么写的:

1
2
3
out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2

但是可能你还有点迷糊,没关系接着看下面的直观理解部分,然后再回来看这段话,结合着看,相信你很快能明白。

由于我们是按照 index 来取值的,所以最终得到的 tensor 的 shape 也是和 index 一样的,就像我们在列表上按索引取值,得到的输出列表长度和索引相等一样。

直观理解

为便于理解,我们以一个具体例子来说明。我们使用反推法,根据 input 和输出推参数。这应该也是我们平常自己写代码的时候遇到比较多的情况。

假设 input 和我们想要的输出 output 如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
>>> output_tensor # shape: (2, 2, 4)
tensor([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[20, 21, 22, 23]]])

即,我们想让 shape 为 (2, 3, 4)input_tensor 变成 shape 为 (2, 2, 4)output_tensor ,丢弃维度 1 的第 2 个元素,即 [ 4, 5, 6, 7][16, 17, 18, 19]

我们应用替换法,重点是找出来 dimindex 的值。始终记住 indexoutput_tensor 的 shape 是一样的。

output_tensor 的第一个位置开始,由于 output_tensor[0, 0, :] = input_tensor[0, 0, :] ,所以此时 [i, j, k] 是一样的,我们看不出来 dim 应该是多少。

下一行 output_tensor[0, 1, 0] = input_tensor[0, 2, 0] ,这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2, index_tensor[0, 1, 0]=2

此时 dim 已经明确。同理,output_tensor[0, 1, 1] = input_tensor[0, 2, 1]index_tensor[0, 1, 1]=2 ,依此类推,得到 index_tensor[0, 1, :] = 2 。同时也可以明确 index_tensor[0, 0, :] = 0

所以

1
2
3
4
5
6
7
>>> dim = 0
>>> index_tensor
tensor([[[0, 0, 0, 0],
[2, 2, 2, 2]],

[[0, 0, 0, 0],
[2, 2, 2, 2]]])

简单可描述如下图:

torch.gather 执行过程

为描述方便,假如我们把输入看作是 6 行,从上到下依次是 0-5。那么从事后诸葛亮的角度讲,输出相当于是把第 1 和第 4 行“抽掉”。如果输出和输入一样,那么原本的 index_tensor 就是如下:

1
2
3
4
5
6
7
tensor([[[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]],

[[0, 0, 0, 0],
[1, 1, 1, 1],
[2, 2, 2, 2]]])

“抽掉”后, index_tensor 也相应“抽掉”,那么就得到我们想要的结果了。而且由于这个“抽掉”的操作是在维度 1 上进行的,那么 dim 自然是 1。

numpy.take()tf.gather 貌似也是同样功能,就不细说了。

Reference

END

CATALOG
  1. 1. 文字理解
  2. 2. 直观理解
  3. 3. Reference
  4. 4. END