W3cubDocs

/PyTorch 2.9

torch.gather

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Gathers values along an axis specified by dim.

For a 3-D tensor the output is specified by:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

input and index must have the same number of dimensions. It is also required that index.size(d) <= input.size(d) for all dimensions d != dim. out will have the same shape as index. Note that input and index do not broadcast against each other. When index is empty, we always return an empty output with the same shape without further error checking.

Parameters
  • input (Tensor) – the source tensor
  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to gather
Keyword Arguments
  • sparse_grad (bool, optional) – If True, gradient w.r.t. input will be a sparse tensor.
  • out (Tensor, optional) – the destination tensor

Example:

>>> t = torch.tensor([[1, 2], [3, 4]])
>>> torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]]))
tensor([[ 1,  1],
        [ 4,  3]])

© 2025, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://docs.pytorch.org/docs/2.9/generated/torch.gather.html