
torch.segment_reduce

torch.segment_reduce(data: Tensor, reduce: str, *, lengths: Tensor | None = None, indices: Tensor | None = None, offsets: Tensor | None = None, axis: _int = 0, unsafe: _bool = False, initial: Number | _complex | None = None) → Tensor

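Perform a segment reduction on the input tensor along the specified axis. data is split into contiguous segments along axis, described by either lengths or offsets (one of them must be given), and each segment is reduced to a single slice of the output.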
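To make the semantics concrete, here is a minimal pure-Python reference for the 1-D case with lengths. It is an illustrative sketch, not PyTorch's implementation. The function name segment_reduce_reference is hypothetical, and the values used to seed each segment when initial is None (0 for sum and mean, 1 for prod, -inf for max, +inf for min, NaN for the mean of an empty segment) are an assumption about the defaults. It does not model NaN propagation.

```python
import math
from typing import List, Optional, Sequence

# Assumed seed values for each segment when no `initial` is given.
_IDENTITY = {"sum": 0.0, "mean": 0.0, "prod": 1.0, "max": -math.inf, "min": math.inf}

# How one more element is folded into a segment's running value.
_COMBINE = {
    "sum": lambda acc, x: acc + x,
    "mean": lambda acc, x: acc + x,  # accumulate a sum, divide at the end
    "prod": lambda acc, x: acc * x,
    "max": max,
    "min": min,
}


def segment_reduce_reference(
    data: Sequence[float],
    reduce: str,
    lengths: Sequence[int],
    initial: Optional[float] = None,
) -> List[float]:
    """Hypothetical helper: reduce consecutive runs of `data`, one per entry of `lengths`."""
    # Mirrors the kind of checks that `unsafe=True` is meant to skip.
    if any(n < 0 for n in lengths):
        raise ValueError("lengths must be non-negative")
    if sum(lengths) != len(data):
        raise ValueError("lengths must sum to len(data)")
    out: List[float] = []
    start = 0
    for n in lengths:
        acc = _IDENTITY[reduce] if initial is None else initial
        for x in data[start:start + n]:
            acc = _COMBINE[reduce](acc, x)
        if reduce == "mean":
            # An empty segment has no mean unless `initial` was supplied.
            acc = acc / n if n > 0 else (math.nan if initial is None else acc)
        out.append(acc)
        start += n
    return out


print(segment_reduce_reference([1, 2, 3, 4, 5, 6], "sum", [2, 0, 4]))
# → [3.0, 0.0, 18.0]
```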

Parameters
  • data (Tensor) – The input tensor on which the segment reduction operation will be performed.
  • reduce (str) – The type of reduction to perform. Supported values are 'sum', 'mean', 'max', 'min', and 'prod'.
Keyword Arguments
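  • lengths (Tensor, optional) – Length of each segment along axis. Default: None.
  • indices (Tensor, optional) – Index-based segmentation; not supported yet, and passing it raises an error. Default: None.
  • offsets (Tensor, optional) – Start offset of each segment along axis, followed by the total length, so it has one more entry than there are segments. Default: None.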
  • axis (int, optional) – The axis along which to perform the reduction. Default: 0.
  • unsafe (bool, optional) – If True, skip validation of the segment arguments. Default: False.
  • initial (Number, optional) – The initial value for the reduction operation. Default: None.
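lengths and offsets are two encodings of the same segmentation. The sketch below assumes offsets holds the start of each segment followed by the total length along axis (a leading 0 plus the running sum of lengths) and checks that both encodings give the same result on CPU. The variable names are illustrative. The empty middle segment shows a zero-length segment reducing to 0 under 'sum'.

```python
import torch

data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
lengths = torch.tensor([2, 0, 4])  # three segments; the middle one is empty

# Assumed offsets layout: leading 0, then the running total of lengths,
# giving one more entry than there are segments: [0, 2, 2, 6].
offsets = torch.cat([lengths.new_zeros(1), lengths.cumsum(0)])

by_lengths = torch.segment_reduce(data, "sum", lengths=lengths)
by_offsets = torch.segment_reduce(data, "sum", offsets=offsets)

print(by_lengths.tolist())  # [3.0, 0.0, 18.0]
print(torch.equal(by_lengths, by_offsets))  # True
```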

Example:

>>> data = torch.tensor([[1, 2, 3, 4],[5, 6, 7, 8],[9, 10, 11, 12]], dtype=torch.float32, device='cuda')
>>> lengths = torch.tensor([2, 1], device='cuda')
>>> torch.segment_reduce(data, 'max', lengths=lengths)
tensor([[ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]], device='cuda:0')
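The example above requires a CUDA device. Assuming a PyTorch build whose CPU kernel supports segment_reduce, the same call can run on CPU tensors. The sketch below does so and cross-checks the result against a split-then-reduce written with torch.split and Tensor.amax. The cross-check only illustrates the semantics; it is not how segment_reduce is implemented. It also assumes every segment is non-empty, because amax of an empty chunk raises an error.

```python
import torch

data = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], dtype=torch.float32)
lengths = torch.tensor([2, 1])  # rows 0-1 form segment 0, row 2 forms segment 1

out = torch.segment_reduce(data, "max", lengths=lengths)

# Cross-check: split the rows into the same segments along axis 0,
# take the column-wise max of each chunk, and stack the results.
chunks = torch.split(data, lengths.tolist(), dim=0)
expected = torch.stack([chunk.amax(dim=0) for chunk in chunks])

print(torch.equal(out, expected))  # True
```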

© 2025, PyTorch Contributors
PyTorch has a BSD-style license, as found in the LICENSE file.
https://docs.pytorch.org/docs/2.9/generated/torch.segment_reduce.html