A transformation that batches ragged elements into tf.SparseTensors.
tf.data.experimental.dense_to_sparse_batch(batch_size, row_shape)
Like Dataset.padded_batch(), this transformation combines multiple consecutive elements of the dataset, which might have different shapes, into a single element. The resulting element has three components (indices, values, and dense_shape), which comprise a tf.SparseTensor that represents the same data. The row_shape represents the dense shape of each row in the resulting tf.SparseTensor, to which the effective batch size is prepended. For example:
    # NOTE: The following examples use `{ ... }` to represent the
    # contents of a dataset.
    a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }

    a.apply(tf.data.experimental.dense_to_sparse_batch(
        batch_size=2, row_shape=[6])) ==
    {
        ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]],  # indices
         ['a', 'b', 'c', 'a', 'b'],                 # values
         [2, 6]),                                   # dense_shape
        ([[0, 0], [0, 1], [0, 2], [0, 3]],
         ['a', 'b', 'c', 'd'],
         [1, 6])
    }
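To make the layout of the three components concrete, the following is a minimal pure-Python sketch that mimics the batching for rank-1 rows only. It is an illustration of the semantics described above, not TensorFlow's implementation; the helper name `dense_to_sparse_batch_sketch` is hypothetical, and plain lists stand in for the dataset and tensors.

```python
def dense_to_sparse_batch_sketch(elements, batch_size, row_shape):
    """Mimic the batching for rank-1 rows, using plain Python lists.

    Hypothetical helper for illustration only; not part of TensorFlow.
    """
    (max_len,) = row_shape  # this sketch assumes rank-1 rows
    batches = []
    for start in range(0, len(elements), batch_size):
        chunk = elements[start:start + batch_size]
        indices, values = [], []
        for row, element in enumerate(chunk):
            # Each element must fit within row_shape.
            if len(element) > max_len:
                raise ValueError("element is larger than row_shape")
            for col, value in enumerate(element):
                indices.append([row, col])
                values.append(value)
        # The effective batch size is prepended to row_shape.
        dense_shape = [len(chunk), max_len]
        batches.append((indices, values, dense_shape))
    return batches


a = [['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd']]
result = dense_to_sparse_batch_sketch(a, batch_size=2, row_shape=[6])
for indices, values, dense_shape in result:
    print(indices, values, dense_shape)
```

Note that the final batch holds only one row, so its dense_shape starts with 1 rather than with batch_size.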
Args | |
---|---|
batch_size | A tf.int64 scalar tf.Tensor, representing the number of consecutive elements of this dataset to combine in a single batch. |
row_shape | A tf.TensorShape or tf.int64 vector tensor-like object representing the equivalent dense shape of a row in the resulting tf.SparseTensor. Each element of this dataset must have the same rank as row_shape, and must have size less than or equal to row_shape in each dimension. |
Returns | |
---|---|
A Dataset transformation function, which can be passed to tf.data.Dataset.apply. |
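As a rough usage sketch, the returned function can be passed to tf.data.Dataset.apply as shown below. This assumes graph-mode execution as in TensorFlow 1.15 and uses integer rows instead of strings; the input data and the names `lengths`, `dataset`, `batched`, and `batches` are illustrative only.

```python
import tensorflow as tf

# Assumption: graph mode, as in TensorFlow 1.15. This call only matters on
# versions where eager execution is enabled by default.
tf.compat.v1.disable_eager_execution()

# Illustrative ragged input: rows [1, 2, 3], [1, 2] and [1, 2, 3, 4],
# built by mapping each length n to the vector 1..n.
lengths = tf.data.Dataset.from_tensor_slices([3, 2, 4])
dataset = lengths.map(lambda n: tf.range(1, n + 1))

# Batch two rows at a time; each row may hold at most 6 values.
batched = dataset.apply(
    tf.data.experimental.dense_to_sparse_batch(batch_size=2, row_shape=[6]))

iterator = tf.compat.v1.data.make_one_shot_iterator(batched)
next_batch = iterator.get_next()  # a tf.SparseTensor

batches = []
with tf.compat.v1.Session() as sess:
    try:
        while True:
            batches.append(sess.run(next_batch))
    except tf.errors.OutOfRangeError:
        pass  # the dataset is exhausted

for batch in batches:
    print(batch.indices.tolist(), batch.values.tolist(),
          batch.dense_shape.tolist())
```

Each evaluated batch exposes the indices, values, and dense_shape components described above.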
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/data/experimental/dense_to_sparse_batch