tf.contrib.estimator.boosted_trees_classifier_train_in_memory

tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
    train_input_fn,
    feature_columns,
    model_dir=None,
    n_classes=canned_boosted_trees._HOLD_FOR_MULTI_CLASS_SUPPORT,
    weight_column=None,
    label_vocabulary=None,
    n_trees=100,
    max_depth=6,
    learning_rate=0.1,
    l1_regularization=0.0,
    l2_regularization=0.0,
    tree_complexity=0.0,
    min_node_weight=0.0,
    config=None,
    train_hooks=None
)

Defined in tensorflow/contrib/estimator/python/estimator/boosted_trees.py.

Trains a boosted tree classifier with an in-memory dataset.

Example:

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.x

# BUCKET_BOUNDARIES_1 and BUCKET_BOUNDARIES_2 are your own lists of boundaries.
bucketized_feature_1 = tf.feature_column.bucketized_column(
  tf.feature_column.numeric_column('feature_1'), BUCKET_BOUNDARIES_1)
bucketized_feature_2 = tf.feature_column.bucketized_column(
  tf.feature_column.numeric_column('feature_2'), BUCKET_BOUNDARIES_2)

def train_input_fn():
  # create_dataset_from_training_data is a placeholder for your own code.
  dataset = create_dataset_from_training_data()
  # This is a tf.data.Dataset of (feature dict, label) tuples,
  #   e.g. Dataset.zip((Dataset.from_tensors({'f1': f1_array, ...}),
  #                     Dataset.from_tensors(label_array)))
  # The returned Dataset shouldn't be batched.
  # If the Dataset repeats, only the first repetition is used for training.
  return dataset

classifier = tf.contrib.estimator.boosted_trees_classifier_train_in_memory(
    train_input_fn,
    feature_columns=[bucketized_feature_1, bucketized_feature_2],
    n_trees=100,
    # ... <some other params>
)

def input_fn_eval():
  dataset = ...  # build the evaluation dataset
  return dataset

metrics = classifier.evaluate(input_fn=input_fn_eval, steps=10)
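To make concrete what n_trees and learning_rate control when the whole dataset sits in memory, here is a toy, pure-Python sketch of gradient boosting with logistic loss. It is not TensorFlow's implementation: it grows depth-1 trees (stumps) on a single, already-bucketized feature and uses the common Newton-step leaf value -G / (H + l2). The names train_stumps and predict_proba, and the (threshold, left_value, right_value) tree format, are hypothetical.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def train_stumps(bucket_ids, labels, n_buckets, n_trees=10,
                 learning_rate=0.1, l2=1.0):
    """Toy in-memory boosting with depth-1 trees on one bucketized feature.

    Hypothetical helper, not TensorFlow code. Each round sees every example
    at once. Returns a list of (threshold, left_value, right_value) stumps;
    requires n_buckets >= 2.
    """
    logits = [0.0] * len(labels)
    trees = []
    for _ in range(n_trees):
        # Gradient and hessian of the logistic loss w.r.t. each logit.
        probs = [sigmoid(z) for z in logits]
        grads = [p - y for p, y in zip(probs, labels)]
        hess = [p * (1.0 - p) for p in probs]
        g_total, h_total = sum(grads), sum(hess)
        best = None
        # Try every bucket boundary t: buckets < t go to the left leaf.
        for t in range(1, n_buckets):
            gl = sum(g for g, b in zip(grads, bucket_ids) if b < t)
            hl = sum(h for h, b in zip(hess, bucket_ids) if b < t)
            gr, hr = g_total - gl, h_total - hl
            gain = gl * gl / (hl + l2) + gr * gr / (hr + l2)
            if best is None or gain > best[0]:
                best = (gain, t, -gl / (hl + l2), -gr / (hr + l2))
        _, t, left, right = best
        # Shrink the new tree's leaf values by learning_rate, then add it.
        stump = (t, learning_rate * left, learning_rate * right)
        trees.append(stump)
        for i, b in enumerate(bucket_ids):
            logits[i] += stump[1] if b < t else stump[2]
    return trees


def predict_proba(trees, bucket_id):
    """Sum the stumps' outputs into a logit and map it to P(label == 1)."""
    logit = sum(left if bucket_id < t else right for t, left, right in trees)
    return sigmoid(logit)
```

For example, train_stumps([0, 0, 1, 1], [0, 0, 1, 1], n_buckets=2, n_trees=20, learning_rate=0.3) yields 20 stumps that all split at bucket 1, and predict_proba then gives bucket 1 a probability above 0.5 and bucket 0 one below it. A smaller learning_rate needs more trees to reach the same fit, which is the usual trade-off between the two parameters.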

Args:

  • train_input_fn: An input function that returns a dataset containing a single epoch of unbatched features and labels.
  • feature_columns: An iterable containing all the feature columns used by the model. All items should be instances of classes derived from FeatureColumn.
  • model_dir: Directory to save model parameters, graph, etc. This can also be used to load checkpoints from the directory into an estimator to continue training a previously saved model.
  • n_classes: Number of label classes. The default is binary classification; multiclass support is not yet implemented.
  • weight_column: A string or a _NumericColumn created by tf.feature_column.numeric_column, defining the feature column that represents example weights. It is used to downweight or boost examples during training: the weight is multiplied by the example's loss. If it is a string, it is used as a key to fetch the weight tensor from the features. If it is a _NumericColumn, the raw tensor is fetched by key weight_column.key, then weight_column.normalizer_fn is applied to it to get the weight tensor.
  • label_vocabulary: A list of strings representing the possible label values. If given, labels must be strings with values in label_vocabulary. If not given, labels must already be encoded: as integers or floats within [0, 1] for n_classes=2, and as integers in {0, 1, ..., n_classes-1} for n_classes>2. Errors will occur if the vocabulary is not provided and the labels are strings.
  • n_trees: Number of trees to create.
  • max_depth: Maximum depth of each tree.
  • learning_rate: Shrinkage parameter applied when a tree is added to the model.
  • l1_regularization: Regularization multiplier applied to the absolute weights of the tree leaves.
  • l2_regularization: Regularization multiplier applied to the squared weights of the tree leaves.
  • tree_complexity: Regularization factor that penalizes trees with more leaves.
  • min_node_weight: Minimum hessian a node must have for a split to be considered. The value is compared with sum(leaf_hessian) / (batch_size * n_batches_per_layer).
  • config: A RunConfig object to configure the runtime settings.
  • train_hooks: A list of Hook instances to be passed to estimator.train().
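The regularization parameters above act on the gradient and hessian sums that a node accumulates. The sketch below shows one common, XGBoost-style way they can enter the leaf value and the split gain; it is an assumption for illustration, not TensorFlow's kernel code, whose constant factors and edge cases may differ. leaf_weight_and_gain, split_gain and the normalizer argument (standing in for batch_size * n_batches_per_layer) are hypothetical names.

```python
import math


def leaf_weight_and_gain(grad_sum, hess_sum, l1=0.0, l2=0.0):
    """Regularized leaf value and gain for a node with the given sums.

    Hypothetical helper. grad_sum/hess_sum are sums of per-example gradients
    and hessians of the loss, each already multiplied by the example weight
    (cf. weight_column).
    """
    # l1 soft-thresholds the gradient sum: small sums give a zero leaf.
    if abs(grad_sum) <= l1:
        return 0.0, 0.0
    g = grad_sum - math.copysign(l1, grad_sum)
    # l2 is added to the hessian sum, shrinking the leaf value toward zero.
    weight = -g / (hess_sum + l2)
    gain = g * g / (hess_sum + l2)
    return weight, gain


def split_gain(left, right, l1=0.0, l2=0.0, tree_complexity=0.0,
               min_node_weight=0.0, normalizer=1.0):
    """Gain of splitting a node into children with (grad_sum, hess_sum) pairs.

    Returns None if either child fails the min_node_weight check. Assumption:
    the check is applied to each child's hessian sum divided by `normalizer`.
    """
    if min(left[1], right[1]) / normalizer < min_node_weight:
        return None
    parent = (left[0] + right[0], left[1] + right[1])
    _, gain_left = leaf_weight_and_gain(*left, l1=l1, l2=l2)
    _, gain_right = leaf_weight_and_gain(*right, l1=l1, l2=l2)
    _, gain_parent = leaf_weight_and_gain(*parent, l1=l1, l2=l2)
    # tree_complexity is charged once per split, i.e. per extra leaf.
    return gain_left + gain_right - gain_parent - tree_complexity
```

For instance, a leaf with gradient sum -4 and hessian sum 2 gets the value 1.0 and gain 3.0 under l1=1 and l2=1, versus 2.0 and 8.0 unregularized. Per the learning_rate description, the leaf value would then be scaled by the shrinkage factor when the tree is added to the model.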

Returns:

A BoostedTreesClassifier instance created with the given arguments and trained on the data loaded into memory from train_input_fn.

Raises:

  • ValueError: When invalid arguments are given or unsupported functionality is requested.

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/estimator/boosted_trees_classifier_train_in_memory