Object that returns a tf.data.Dataset when invoked.
tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=None
)
tf.keras.utils.experimental.DatasetCreator is designated as a supported type for x, the input argument of tf.keras.Model.fit. Pass an instance of this class to fit when using a callable (taking an input_context argument) that returns a tf.data.Dataset.
import tensorflow as tf
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
def dataset_fn(input_context):
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset
input_options = tf.distribute.InputOptions(
    experimental_fetch_to_device=True,
    experimental_per_replica_buffer_size=2)
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
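The dataset_fn above relies on two pieces of arithmetic: input_context.get_per_replica_batch_size divides the global batch size by the number of replicas in sync (raising ValueError when it does not divide evenly), and dataset.shard(num_shards, index) keeps every num_shards-th element starting at position index, so each input pipeline sees a disjoint slice of the data. The plain-Python sketch below models that arithmetic without TensorFlow; per_replica_batch_size and shard are hypothetical helper names for illustration, not TensorFlow APIs, and the 2-worker, 4-replica setup is an assumed example.

```python
def per_replica_batch_size(global_batch_size: int, num_replicas_in_sync: int) -> int:
    """Hypothetical model of InputContext.get_per_replica_batch_size."""
    if global_batch_size % num_replicas_in_sync != 0:
        raise ValueError(
            f"global_batch_size {global_batch_size} is not divisible by "
            f"num_replicas_in_sync {num_replicas_in_sync}")
    return global_batch_size // num_replicas_in_sync


def shard(elements, num_shards: int, index: int):
    """Hypothetical model of Dataset.shard: keep items whose position % num_shards == index."""
    return [x for i, x in enumerate(elements) if i % num_shards == index]


# Assumed setup: 2 input pipelines (workers), 4 replicas in sync, global batch of 64.
num_input_pipelines, num_replicas = 2, 4
print(per_replica_batch_size(64, num_replicas))  # 16
for pipeline_id in range(num_input_pipelines):
    # Each pipeline gets a disjoint, interleaved slice of the input.
    print(pipeline_id, shard(list(range(6)), num_input_pipelines, pipeline_id))
# 0 [0, 2, 4]
# 1 [1, 3, 5]
```

Because each pipeline batches at the per-replica size, one training step across 4 replicas consumes 4 × 16 = 64 examples, which is the global batch size.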
Model.fit usage with DatasetCreator is intended to work across all tf.distribute.Strategy types, as long as the model is created under Strategy.scope:
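# Assumes cluster_resolver is a tf.distribute.cluster_resolver.ClusterResolver
# (for example, TFConfigClusterResolver) describing the parameter-server cluster.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)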
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
def dataset_fn(input_context):
  ...
input_options = ...
model.fit(tf.keras.utils.experimental.DatasetCreator(
    dataset_fn, input_options=input_options), epochs=10, steps_per_epoch=10)
Note: When using DatasetCreator, the steps_per_epoch argument in Model.fit must be provided, because the cardinality of such input cannot be inferred.
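If the underlying data is finite and you know how many examples it holds, one common way to choose steps_per_epoch is to divide that count by the global batch size and round up. The helper below is a hypothetical sketch of that arithmetic, not part of Keras; it assumes the dataset returned by dataset_fn is repeated (as in the example above), so a final partial step is simply filled from the next pass over the data.

```python
def steps_for_one_pass(num_examples: int, global_batch_size: int) -> int:
    """Hypothetical helper: steps needed to visit every example at least once.

    Uses integer ceiling division so a trailing partial batch still gets a step.
    Assumes the dataset is repeated, so that step is completed from the next pass.
    """
    if num_examples <= 0 or global_batch_size <= 0:
        raise ValueError("num_examples and global_batch_size must be positive")
    return (num_examples + global_batch_size - 1) // global_batch_size


print(steps_for_one_pass(1000, 64))  # 16
```

The result would then be passed as model.fit(..., steps_per_epoch=steps_for_one_pass(num_examples, 64)), where num_examples is whatever example count you assume for your data.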
| Args | |
|---|---|
| dataset_fn | A callable that takes a single argument of type tf.distribute.InputContext, which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the InputContext parameter can be ignored in the dataset_fn), and returns a tf.data.Dataset. | 
| input_options | Optional tf.distribute.InputOptions, used to set input options for distributed training, for example whether to prefetch dataset elements to accelerator device memory or host device memory, and the prefetch buffer size in the replica device memory. Has no effect if not used with distributed training. See tf.distribute.InputOptions for more information. | 
__call__
__call__(
    *args, **kwargs
)
Call self as a function: forwards *args and **kwargs to dataset_fn and returns the resulting tf.data.Dataset.
    © 2022 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 4.0.
Code samples licensed under the Apache 2.0 License.
    https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/keras/utils/experimental/DatasetCreator