Splits input tensor across all dimensions.
tf.raw_ops.XlaSplitND(
input, N, num_splits, paddings=[], name=None
)
An op that slices the input tensor according to the given num_splits attribute, optionally right-pads the input with zeros before splitting, and returns the slices. Slices are returned in row-major order.
This op may be generated via the TPU bridge.
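"Row-major order" means the last dimension's block index varies fastest. The sketch below maps the k-th returned slice back to its per-dimension block coordinates. block_coordinates is a hypothetical helper written for illustration and is not part of TensorFlow.

```python
def block_coordinates(k, num_splits):
    """Map output index k of a row-major split to its per-dimension block index.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    total = 1
    for n in num_splits:
        total *= n
    if not 0 <= k < total:
        raise IndexError(f"output index {k} out of range for {total} slices")
    coords = []
    # Row-major: the last dimension varies fastest, so peel it off first.
    for n in reversed(num_splits):
        k, r = divmod(k, n)
        coords.append(r)
    return tuple(reversed(coords))


print([block_coordinates(k, [2, 2]) for k in range(4)])
# [(0, 0), (0, 1), (1, 0), (1, 1)]
```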
For example, with input tensor:
[[0, 1, 2], [3, 4, 5], [6, 7, 8]]
num_splits:
[2, 2]
and paddings:
[1, 1]
the expected outputs are:
[[0, 1], [3, 4]]
[[2, 0], [5, 0]]
[[6, 7], [0, 0]]
[[8, 0], [0, 0]]
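To make the padding and ordering in this example concrete, here is a NumPy sketch that emulates the described semantics. It assumes zero right-padding (as the example shows), that num_splits has one entry per input dimension, and that paddings is either empty or has one entry per dimension. split_nd_reference is a hypothetical name; this is an illustration, not TensorFlow's kernel.

```python
import itertools

import numpy as np


def split_nd_reference(x, num_splits, paddings=()):
    """NumPy sketch of XlaSplitND semantics (illustrative, not the TF kernel)."""
    x = np.asarray(x)
    if len(num_splits) != x.ndim:
        raise ValueError("num_splits needs one entry per input dimension")
    paddings = list(paddings) or [0] * x.ndim
    if len(paddings) != x.ndim:
        raise ValueError("paddings needs one entry per input dimension")

    # Pad only on the right (high) side of each dimension, with zeros.
    x = np.pad(x, [(0, p) for p in paddings], mode="constant")

    slice_shape = []
    for size, n in zip(x.shape, num_splits):
        if size % n != 0:
            raise ValueError(f"padded size {size} is not divisible by {n}")
        slice_shape.append(size // n)

    outputs = []
    # itertools.product varies the last index fastest, i.e. row-major order.
    for block in itertools.product(*(range(n) for n in num_splits)):
        region = tuple(slice(b * s, (b + 1) * s) for b, s in zip(block, slice_shape))
        outputs.append(x[region])
    return outputs


inputs = [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
for piece in split_nd_reference(inputs, num_splits=[2, 2], paddings=[1, 1]):
    print(piece.tolist())
# [[0, 1], [3, 4]]
# [[2, 0], [5, 0]]
# [[6, 7], [0, 0]]
# [[8, 0], [0, 0]]
```

Padding the whole input first and then cutting equal-sized blocks is why only the trailing slices along each dimension pick up zeros.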
| Args | |
|---|---|
| input | A Tensor. Input tensor to split across all dimensions. |
| N | An int that is >= 1. The number of output slices; it must equal the product of the values in num_splits. |
| num_splits | A list of ints. Number of ways to split per dimension. Each dimension of input, after padding, must be evenly divisible by its number of splits. |
| paddings | An optional list of ints. Defaults to []. Optional list of right paddings per dimension of the input tensor, applied before splitting. This can be used to make a dimension evenly divisible. |
| name | A name for the operation (optional). |
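Before calling the op, it can help to check the argument constraints in the table and derive the value to pass as N. The helper below is a hypothetical pre-flight check that restates those constraints (one num_splits entry per dimension, padded sizes evenly divisible, N equal to the product of num_splits); it is not part of TensorFlow and does not call the op.

```python
import math


def split_nd_plan(input_shape, num_splits, paddings=()):
    """Return (N, slice_shape) for a set of XlaSplitND-style arguments.

    Hypothetical helper for illustration; not part of TensorFlow.
    """
    if len(num_splits) != len(input_shape):
        raise ValueError("num_splits needs one entry per input dimension")
    paddings = list(paddings) or [0] * len(input_shape)
    if len(paddings) != len(input_shape):
        raise ValueError("paddings needs one entry per input dimension")

    slice_shape = []
    for size, pad, n in zip(input_shape, paddings, num_splits):
        padded = size + pad  # right padding is applied before splitting
        if padded % n != 0:
            raise ValueError(f"padded size {padded} is not divisible by {n}")
        slice_shape.append(padded // n)

    # N is the total number of slices across all dimensions.
    return math.prod(num_splits), slice_shape


print(split_nd_plan([3, 3], [2, 2], paddings=[1, 1]))  # (4, [2, 2])
```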
| Returns | |
|---|---|
| A list of N Tensor objects with the same type as input. | |
© 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/raw_ops/XlaSplitND