
tf.contrib.seq2seq.ScheduledEmbeddingTrainingHelper

Class ScheduledEmbeddingTrainingHelper

Inherits From: TrainingHelper

Defined in tensorflow/contrib/seq2seq/python/ops/helper.py.

See the guide: Seq2seq Library (contrib) > Dynamic Decoding

A training helper that adds scheduled sampling.

Returns -1 in sample_ids wherever no sampling took place, and valid sample ids elsewhere.

Properties

batch_size

Batch size of tensor returned by sample.

Returns a scalar int32 tensor.

inputs

sample_ids_dtype

DType of tensor returned by sample.

Returns a DType.

sample_ids_shape

Shape of tensor returned by sample, excluding the batch dimension.

Returns a TensorShape.

sequence_length

Methods

__init__

__init__(
    inputs,
    sequence_length,
    embedding,
    sampling_probability,
    time_major=False,
    seed=None,
    scheduling_seed=None,
    name=None
)

Initializer.

Args:

  • inputs: A (structure of) input tensors.
  • sequence_length: An int32 vector tensor.
  • embedding: A callable that takes a vector tensor of ids (argmax ids), or the params argument for embedding_lookup.
  • sampling_probability: A 0D (scalar) or 1D (vector) float32 tensor: the probability of sampling categorically from the output ids instead of reading directly from the inputs.
  • time_major: Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major.
  • seed: The sampling seed.
  • scheduling_seed: The schedule decision rule sampling seed.
  • name: Name scope for any created operations.
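The embedding argument accepts either a callable or the params of an embedding lookup. The plain-Python sketch below illustrates how those two forms could be normalized into a single ids-to-vectors function. It assumes a nested-list table stands in for the params tensor; make_embedding_fn is a hypothetical name, not part of the TensorFlow API.

```python
from typing import Callable, List, Sequence, Union

EmbeddingFn = Callable[[List[int]], List[List[float]]]


def make_embedding_fn(
    embedding: Union[EmbeddingFn, Sequence[Sequence[float]]]
) -> EmbeddingFn:
    """Hypothetical sketch: turn `embedding` into a callable ids -> vectors.

    A callable is returned unchanged. Anything else is treated like the
    params of an embedding lookup: a table whose row i is the vector for id i.
    """
    if callable(embedding):
        return embedding
    table = [list(row) for row in embedding]

    def lookup(ids: List[int]) -> List[List[float]]:
        # Gather one row of the table per id.
        return [table[i] for i in ids]

    return lookup
```

For example, make_embedding_fn([[0.0, 0.0], [1.0, 2.0]])([1]) yields [[1.0, 2.0]], while a user-supplied callable is passed through untouched.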

Raises:

  • ValueError: if sampling_probability is not a scalar or vector.
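A minimal sketch of the rank check described under Raises. It assumes Python numbers and flat lists stand in for 0D and 1D tensors. check_sampling_probability is hypothetical and only mirrors the documented rule (a scalar or a vector, nothing deeper). It does not validate the probability values themselves.

```python
from typing import Sequence, Union


def check_sampling_probability(p: Union[float, Sequence[float]]) -> int:
    """Hypothetical sketch: return the rank of `p`, or raise ValueError.

    Rank 0 (a single number) and rank 1 (a flat list/tuple of numbers)
    are accepted; anything nested more deeply is rejected.
    """
    if isinstance(p, (int, float)):
        return 0
    if isinstance(p, (list, tuple)) and all(isinstance(x, (int, float)) for x in p):
        return 1
    raise ValueError(
        "sampling_probability must be either a scalar or a vector, got %r" % (p,)
    )
```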

initialize

initialize(name=None)

Returns (initial_finished, initial_inputs).
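As a rough illustration of what the returned pair contains, here is a plain-Python sketch. It assumes batch-major nested lists ([batch][time][depth]) with at least one time step, and reflects our understanding of the behavior inherited from TrainingHelper. An example counts as finished at time 0 only if its sequence_length is 0. The initial inputs are the time-0 inputs, or zeros when every example is already finished. The function name initialize_sketch is hypothetical.

```python
from typing import List, Tuple


def initialize_sketch(
    inputs: List[List[List[float]]], sequence_length: List[int]
) -> Tuple[List[bool], List[List[float]]]:
    """Hypothetical sketch of initialize() for batch-major inputs.

    Assumes every example has at least one (possibly padded) time step.
    """
    # An example is finished before decoding starts only if it is empty.
    initial_finished = [length == 0 for length in sequence_length]
    step0 = [example[0] for example in inputs]
    if all(initial_finished):
        # Nothing to decode: feed zeros shaped like one time step.
        initial_inputs = [[0.0] * len(x) for x in step0]
    else:
        # Otherwise feed the ground-truth inputs at time 0.
        initial_inputs = [list(x) for x in step0]
    return initial_finished, initial_inputs
```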

next_inputs

next_inputs(
    time,
    outputs,
    state,
    sample_ids,
    name=None
)

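Returns (finished, next_inputs, next_state). This is the next_inputs_fn for TrainingHelper: wherever sample_ids is not -1, the next input is the embedding of the sampled id; elsewhere it is the ground-truth input from inputs.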
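The merge between sampled and ground-truth inputs can be sketched in plain Python as follows. The sketch assumes batch-major padded nested lists, sequence lengths no longer than the padded time dimension, and an embedding_fn that maps a list of ids to a list of vectors. next_inputs_sketch is a hypothetical name, and the zero-input behavior when all examples are finished reflects our understanding of the base TrainingHelper rather than a documented guarantee.

```python
from typing import Callable, List, Optional, Tuple


def next_inputs_sketch(
    time: int,
    inputs: List[List[List[float]]],
    sequence_length: List[int],
    sample_ids: List[int],
    embedding_fn: Callable[[List[int]], List[List[float]]],
    state: Optional[object] = None,
) -> Tuple[List[bool], List[List[float]], Optional[object]]:
    """Hypothetical sketch of next_inputs() for batch-major inputs."""
    next_time = time + 1
    # An example is finished once the next step reaches its length.
    finished = [next_time >= length for length in sequence_length]
    if all(finished):
        # Every example is done: feed zeros shaped like one time step.
        zeros = [[0.0] * len(example[0]) for example in inputs]
        return finished, zeros, state
    merged = []
    for sample_id, example in zip(sample_ids, inputs):
        if sample_id > -1:
            # Sampling took place: feed the embedding of the sampled id.
            merged.append(list(embedding_fn([sample_id])[0]))
        else:
            # No sampling: feed the ground-truth input at the next step.
            merged.append(list(example[next_time]))
    return finished, merged, state
```

The state is passed through unchanged, since the helper only decides what to feed into the next decoder step.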

sample

sample(
    time,
    outputs,
    state,
    name=None
)

Returns sample_ids.
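The sampling rule can be sketched in plain Python. Each example is first selected for sampling with probability sampling_probability. Selected examples get an id drawn categorically from the softmax of their output logits, and the rest get -1. The two separate random generators are an assumption meant to echo the seed and scheduling_seed arguments, and treating outputs as per-example logits is also an assumption. scheduled_sample and softmax are hypothetical helpers, not TensorFlow API.

```python
import math
import random
from typing import List, Optional


def softmax(logits: List[float]) -> List[float]:
    # Numerically stable softmax over one example's logits.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def scheduled_sample(
    outputs: List[List[float]],
    sampling_probability: float,
    scheduling_rng: Optional[random.Random] = None,
    sampling_rng: Optional[random.Random] = None,
) -> List[int]:
    """Hypothetical sketch of sample(): one id per batch element, -1 if unsampled."""
    scheduling_rng = scheduling_rng or random.Random()
    sampling_rng = sampling_rng or random.Random()
    sample_ids = []
    for logits in outputs:
        # Schedule decision: does this example use a sampled id at this step?
        if scheduling_rng.random() < sampling_probability:
            probs = softmax(logits)
            # Categorical draw over the vocabulary indices.
            sample_ids.append(sampling_rng.choices(range(len(logits)), weights=probs, k=1)[0])
        else:
            sample_ids.append(-1)
    return sample_ids
```

With sampling_probability at 0.0 every id is -1, so the decoder always reads the ground-truth inputs. At 1.0 every example is sampled.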

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledEmbeddingTrainingHelper