Datasets Quick Start

The tf.data module contains a collection of classes that allows you to easily load data, manipulate it, and pipe it into your model. This document introduces the API by walking through two simple examples:

  • Reading in-memory data from numpy arrays.
  • Reading lines from a csv file.

Basic input

Taking slices from an array is the simplest way to get started with tf.data.

The Premade Estimators chapter describes the following train_input_fn, from iris_data.py, to pipe the data into the Estimator:

def train_input_fn(features, labels, batch_size):
    """An input function for training"""
    # Convert the inputs to a Dataset.
    dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))

    # Shuffle, repeat, and batch the examples.
    dataset = dataset.shuffle(1000).repeat().batch(batch_size)

    # Return the dataset.
    return dataset

Let's look at this more closely.


This function expects three arguments. Arguments expecting an "array" can accept nearly anything that can be converted to an array with numpy.array. One exception is tuple which, as we will see, has special meaning for Datasets.

  • features: A {'feature_name': array} dictionary (or DataFrame) containing the raw input features.
  • labels: An array containing the label for each example.
  • batch_size: An integer indicating the desired batch size.

In premade_estimator.py we retrieved the Iris data using the iris_data.load_data() function. You can run it, and unpack the results as follows:

import iris_data

# Fetch the data
train, test = iris_data.load_data()
features, labels = train

Then we passed this data to the input function, with a line similar to this:

iris_data.train_input_fn(features, labels, batch_size)

Let's walk through the train_input_fn().


The function starts by using the tf.data.Dataset.from_tensor_slices function to create a tf.data.Dataset representing slices of the array. The array is sliced across the first dimension. For example, an array containing the mnist training data has a shape of (60000, 28, 28). Passing this to from_tensor_slices returns a Dataset object containing 60000 slices, each one a 28x28 image.

The code that returns this Dataset is as follows:

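import tensorflow as tf

# Load MNIST as numpy arrays and keep the training split.
train, test = tf.keras.datasets.mnist.load_data()
mnist_x, mnist_y = train

# Slice the (60000, 28, 28) image array along its first dimension.
mnist_ds = tf.data.Dataset.from_tensor_slices(mnist_x)
print(mnist_ds)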

This will print the following line, showing the shapes and types of the items in the dataset. Note that a Dataset does not know how many items it contains.

<TensorSliceDataset shapes: (28,28), types: tf.uint8>

The Dataset above represents a simple collection of arrays, but datasets are much more powerful than this. A Dataset can transparently handle any nested combination of dictionaries or tuples (or namedtuple ).

For example, after converting the iris features to a standard Python dictionary, you can convert the dictionary of arrays to a Dataset of dictionaries as follows:

dataset = tf.data.Dataset.from_tensor_slices(dict(features))
print(dataset)

<TensorSliceDataset

  shapes: {
    SepalLength: (), PetalWidth: (),
    PetalLength: (), SepalWidth: ()},

  types: {
      SepalLength: tf.float64, PetalWidth: tf.float64,
      PetalLength: tf.float64, SepalWidth: tf.float64}
>

Here we see that when a Dataset contains structured elements, the shapes and types of the Dataset take on the same structure. This dataset contains dictionaries of scalars, all of type tf.float64.
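To make the slicing behaviour concrete, here is a pure-Python sketch of the idea. It is not how tf.data is implemented; it only shows a dictionary of equal-length columns being turned into a stream of per-example dictionaries with the same keys. The function name slice_columns and the sample values are assumptions made for illustration.

```python
def slice_columns(columns):
    """Yield one {name: value} dict per example from a dict of equal-length columns."""
    names = list(columns)
    lengths = {len(columns[name]) for name in names}
    if len(lengths) > 1:
        raise ValueError("all columns must have the same length")
    # zip(*...) walks the columns in lockstep, i.e. along the first dimension.
    for row in zip(*(columns[name] for name in names)):
        yield dict(zip(names, row))


# Two made-up examples using two of the Iris feature names.
columns = {"SepalLength": [6.4, 5.0], "PetalWidth": [2.2, 1.0]}
examples = list(slice_columns(columns))
```

Each element of examples keeps the structure of the input (a dictionary keyed by feature name), but holds a single scalar per key instead of a whole column.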

The first line of the iris train_input_fn uses the same functionality, but adds another level of structure. It creates a dataset containing (features_dict, label) pairs.

The following code shows that the label is a scalar with type int64:

# Convert the inputs to a Dataset.
dataset = tf.data.Dataset.from_tensor_slices((dict(features), labels))
print(dataset)

<TensorSliceDataset
    shapes: (
        {
          SepalLength: (), PetalWidth: (),
          PetalLength: (), SepalWidth: ()},
        ()),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>

Currently the Dataset would iterate over the data once, in a fixed order, and only produce a single element at a time. It needs further processing before it can be used for training. Fortunately, the tf.data.Dataset class provides methods to better prepare the data for training. The next line of the input function takes advantage of several of these methods:

# Shuffle, repeat, and batch the examples.
dataset = dataset.shuffle(1000).repeat().batch(batch_size)

The shuffle method uses a fixed-size buffer to shuffle the items as they pass through. In this case the buffer_size (1000) is greater than the number of examples in the Dataset, ensuring that the data is completely shuffled (the Iris data set contains only 150 examples).
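The effect of the buffer size is easier to see in a small pure-Python sketch. This is a simplified model of a fixed-size shuffle buffer, not TensorFlow's implementation, and the name buffered_shuffle is hypothetical. The buffer is filled first; after that, each incoming item replaces a randomly chosen buffered item, which is emitted.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield items in a randomized order using a fixed-size buffer."""
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            # Still filling the buffer: nothing is emitted yet.
            buffer.append(item)
            continue
        # Emit a random buffered item and put the new item in its place.
        index = rng.randrange(buffer_size)
        yield buffer[index]
        buffer[index] = item
    # The input is exhausted: emit what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


small_buffer = list(buffered_shuffle(range(10), buffer_size=3, seed=0))
large_buffer = list(buffered_shuffle(range(10), buffer_size=1000, seed=0))
```

With buffer_size=3 the first item emitted can only be one of the first three inputs, so the shuffle is local. With a buffer at least as large as the input, every item is buffered before anything is emitted, which gives a full shuffle.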

The repeat method restarts the Dataset when it reaches the end. To limit the number of epochs, set the count argument.
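As a rough pure-Python analogy (a sketch, not the tf.data implementation), repeating amounts to restarting the source each time it is exhausted. The helper below and its make_epoch parameter are hypothetical names; count=None models the default of repeating indefinitely.

```python
import itertools


def repeat(make_epoch, count=None):
    """Yield the items of make_epoch() again and again; count=None repeats forever."""
    epoch = 0
    while count is None or epoch < count:
        # Each call to make_epoch() starts a fresh pass over the data.
        yield from make_epoch()
        epoch += 1


two_epochs = list(repeat(lambda: iter("abc"), count=2))
# An unbounded repeat must be cut off by the consumer, here with islice.
first_seven = list(itertools.islice(repeat(lambda: iter("abc")), 7))
```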

The batch method collects a number of examples and stacks them, to create batches. This adds a dimension to their shape. The new dimension is added as the first dimension. The following code uses the batch method on the MNIST Dataset, from earlier. This results in a Dataset containing 3D arrays representing stacks of (28,28) images:

print(mnist_ds.batch(100))

<BatchDataset
  shapes: (?, 28, 28),
  types: tf.uint8>

Note that the dataset has an unknown batch size because the last batch may have fewer elements.
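The reason the batch dimension cannot be fixed in advance shows up in a small pure-Python sketch of batching. This is a simplified model with a hypothetical helper name, not the tf.data implementation: when the number of items is not a multiple of the batch size, the final batch comes out short.

```python
def batch(items, batch_size):
    """Group consecutive items into lists of at most batch_size elements."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    current = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    # Whatever is left over forms a final, smaller batch.
    if current:
        yield current


batches = list(batch(range(7), batch_size=3))
batch_lengths = [len(b) for b in batches]
```

Seven items with a batch size of three give two full batches and one batch holding a single item.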

In train_input_fn, after batching, the Dataset contains 1D vectors of elements where each element was previously a scalar:

print(dataset)

<BatchDataset
    shapes: (
        {
          SepalLength: (?,), PetalWidth: (?,),
          PetalLength: (?,), SepalWidth: (?,)},
        (?,)),

    types: (
        {
          SepalLength: tf.float64, PetalWidth: tf.float64,
          PetalLength: tf.float64, SepalWidth: tf.float64},
        tf.int64)>


At this point the Dataset contains (features_dict, labels) pairs. This is the format expected by the train and evaluate methods, so the input_fn returns the dataset.

The labels can (and should) be omitted when using the predict method.

Reading a CSV File

The most common real-world use case for the Dataset class is to stream data from files on disk. The tf.data module includes a variety of file readers. Let's see how parsing the Iris dataset from the csv file looks using a Dataset.

The following call to the iris_data.maybe_download function downloads the data if necessary, and returns the pathnames of the resulting files:

import iris_data
train_path, test_path = iris_data.maybe_download()

The iris_data.csv_input_fn function contains an alternative implementation that parses the csv files using a Dataset.

Let's look at how to build an Estimator-compatible input function that reads from the local files.

Build the Dataset

We start by building a TextLineDataset object to read the file one line at a time. Then, we call the skip method to skip over the first line of the file, which contains a header, not an example:

ds = tf.data.TextLineDataset(train_path).skip(1)
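For readers more familiar with plain Python file handling, the same idea (stream the lines, drop the header) can be sketched with the standard library. This is only an analogy to TextLineDataset(...).skip(1); the helper name data_lines and the in-memory sample file are invented for the example.

```python
import io
import itertools


def data_lines(text_file, skip=1):
    """Return an iterator over lines without newlines, skipping the first `skip` lines."""
    stripped = (line.rstrip("\n") for line in text_file)
    return itertools.islice(stripped, skip, None)


# An invented three-line file: one header line followed by two examples.
sample = io.StringIO(
    "SepalLength,SepalWidth,PetalLength,PetalWidth,label\n"
    "6.4,2.8,5.6,2.2,2\n"
    "5.0,2.3,3.3,1.0,1\n"
)
lines = list(data_lines(sample))
```

Like the Dataset version, this reads lazily: no line is touched until the consumer asks for it.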

Build a csv line parser

We must parse each of the lines in the dataset in order to generate the necessary (features, label) pairs. We will start by building a function to parse a single line.

The following _parse_line function calls tf.decode_csv to parse a single line into its features and the label. Since Estimators require that features be represented as a dictionary, we rely on Python's built-in dict and zip functions to build that dictionary. The feature names are the keys of that dictionary. We then call the dictionary's pop method to remove the label field from the features dictionary:

# Metadata describing the text columns
COLUMNS = ['SepalLength', 'SepalWidth',
           'PetalLength', 'PetalWidth',
           'label']
# One default per column; the defaults also set each field's type.
FIELD_DEFAULTS = [[0.0], [0.0], [0.0], [0.0], [0]]

def _parse_line(line):
    # Decode the line into its fields
    fields = tf.decode_csv(line, FIELD_DEFAULTS)

    # Pack the result into a dictionary
    features = dict(zip(COLUMNS, fields))

    # Separate the label from the features
    label = features.pop('label')

    return features, label
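The dict, zip, and pop steps can be tried on an ordinary Python string without TensorFlow. The sketch below uses the standard csv module as a stand-in for tf.decode_csv, so it mirrors the logic rather than the real op. The name parse_line_py, the DEFAULTS list, and the sample line are assumptions made for illustration. As with the record defaults passed to tf.decode_csv, each default supplies both the field's type and the value used when the field is empty.

```python
import csv

COLUMNS = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth", "label"]
# One default per column; its Python type also decides how the field is converted.
DEFAULTS = [0.0, 0.0, 0.0, 0.0, 0]


def parse_line_py(line):
    """Parse one CSV line into a (features_dict, label) pair."""
    raw_fields = next(csv.reader([line]))
    if len(raw_fields) != len(COLUMNS):
        raise ValueError("expected %d fields, got %d" % (len(COLUMNS), len(raw_fields)))
    # Convert each field to the type of its default; empty fields take the default.
    fields = [
        type(default)(raw) if raw else default
        for raw, default in zip(raw_fields, DEFAULTS)
    ]
    # Pack the fields into a dictionary keyed by column name.
    features = dict(zip(COLUMNS, fields))
    # Separate the label from the features.
    label = features.pop("label")
    return features, label


features, label = parse_line_py("6.4,2.8,5.6,2.2,2")
```

After the call, features holds the four measurements keyed by name and label holds the integer class, which is the (features, label) shape an Estimator input function is expected to produce.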

Parse the lines

Datasets have many methods for manipulating the data while it is being piped to a model. The most heavily-used method is map, which applies a transformation to each element of the Dataset.

The map method takes a map_func argument that describes how each item in the Dataset should be transformed.

So to parse the lines as they are streamed out of the csv file, we pass our _parse_line function to the map method:

ds = ds.map(_parse_line)
print(ds)

<MapDataset
shapes: (
    {SepalLength: (), PetalWidth: (), ...},
    ()),
types: (
    {SepalLength: tf.float32, PetalWidth: tf.float32, ...},
    tf.int32)>

Now instead of simple scalar strings, the dataset contains (features, label) pairs.

The remainder of the iris_data.csv_input_fn function is identical to iris_data.train_input_fn, which was covered in the Basic input section.

Try it out

This function can be used as a replacement for iris_data.train_input_fn. It can be used to feed an estimator as follows:

train_path, test_path = iris_data.maybe_download()

# All the inputs are numeric
feature_columns = [
    tf.feature_column.numeric_column(name)
    for name in iris_data.CSV_COLUMN_NAMES[:-1]]

# Build the estimator
est = tf.estimator.LinearClassifier(feature_columns,
                                    n_classes=3)
# Train the estimator
batch_size = 100
est.train(
    steps=1000,
    input_fn=lambda : iris_data.csv_input_fn(train_path, batch_size))

Estimators expect an input_fn to take no arguments. To work around this restriction, we use lambda to capture the arguments and provide the expected interface.
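Capturing arguments this way is ordinary Python and can be tried in isolation. The sketch below uses a hypothetical stand_in_input_fn that simply echoes its arguments in place of iris_data.csv_input_fn, and compares the lambda approach with functools.partial, which is another common way to produce a zero-argument callable. The file name "train.csv" is a placeholder.

```python
import functools


def stand_in_input_fn(path, batch_size):
    """Hypothetical stand-in for iris_data.csv_input_fn; it just echoes its arguments."""
    return path, batch_size


batch_size = 100

# Both callables take no arguments, which is the interface an Estimator expects.
with_lambda = lambda: stand_in_input_fn("train.csv", batch_size)
with_partial = functools.partial(stand_in_input_fn, "train.csv", batch_size)

lambda_result = with_lambda()
partial_result = with_partial()
```

One difference worth knowing: the lambda looks up batch_size when it is called, while functools.partial stores the value it was given when it was created.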


The tf.data module provides a collection of classes and functions for easily reading data from a variety of sources. Furthermore, tf.data has simple, powerful methods for applying a wide variety of standard and custom transformations.

Now you have the basic idea of how to efficiently load data into an Estimator. Consider the following documents next:

  • Creating Custom Estimators, which demonstrates how to build your own custom Estimator model.
  • The Low Level Introduction, which demonstrates how to experiment directly with tf.data.Datasets using TensorFlow's low level APIs.
  • Importing Data, which goes into great detail about additional functionality of Datasets.

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.