ScheduledOutputTrainingHelper
Inherits From: TrainingHelper
Defined in tensorflow/contrib/seq2seq/python/ops/helper.py.
See the guide: Seq2seq Library (contrib) > Dynamic Decoding
A training helper that adds scheduled sampling directly to outputs.
Its sample method returns sample_ids that are False for batch entries where no sampling took place and True elsewhere.
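To make "scheduled sampling directly to outputs" concrete, here is a minimal pure-Python sketch that unrolls a toy one-dimensional "RNN" over a single sequence. It assumes the behavior described on this page: the first step always reads the ground-truth input, and every later step feeds back the model's previous output with probability sampling_probability instead of reading the ground truth. The function scheduled_output_unroll and its step_fn parameter are hypothetical; this is an illustration of the idea, not the helper's implementation.

```python
import random
from typing import Callable, List, Optional


def scheduled_output_unroll(
    ground_truth_inputs: List[float],
    step_fn: Callable[[float], float],
    sampling_probability: float,
    seed: Optional[int] = None,
) -> List[float]:
    """Hypothetical toy unroll with scheduled output sampling.

    ground_truth_inputs[t] is the teacher-forced input for step t.
    """
    rng = random.Random(seed)
    outputs: List[float] = []
    for t, truth in enumerate(ground_truth_inputs):
        # Step 0 has no previous output, so it always reads the ground truth.
        if t > 0 and rng.random() < sampling_probability:
            step_input = outputs[-1]  # feed back the model's own output
        else:
            step_input = truth  # teacher forcing
        outputs.append(step_fn(step_input))
    return outputs
```

With sampling_probability=0.0 this reduces to plain teacher forcing, and with 1.0 every step after the first consumes the previous output: for the toy step x + 1 and inputs [0.0, 10.0, 20.0], the two extremes give [1.0, 11.0, 21.0] and [1.0, 2.0, 3.0].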
batch_size
Batch size of tensor returned by sample.
Returns a scalar int32 tensor.
inputs
sample_ids_dtype
DType of tensor returned by sample.
Returns a DType.
sample_ids_shape
Shape of tensor returned by sample, excluding the batch dimension.
Returns a TensorShape.
sequence_length
__init__
__init__(inputs, sequence_length, sampling_probability, time_major=False, seed=None, next_inputs_fn=None, auxiliary_inputs=None, name=None)
Initializer.
Args:
inputs: A (structure of) input tensors.
sequence_length: An int32 vector tensor (one length per batch entry).
sampling_probability: A 0D float32 tensor: the probability of sampling from the outputs instead of reading directly from the inputs.
time_major: Python bool. Whether the tensors in inputs are time major. If False (default), they are assumed to be batch major.
seed: The sampling seed.
next_inputs_fn: (Optional) callable to apply to the RNN outputs to create the next input when sampling. If None (default), the RNN outputs will be used as the next inputs.
auxiliary_inputs: An optional (structure of) auxiliary input tensors with a shape that matches inputs in all but (potentially) the final dimension. These tensors will be concatenated to the sampled output, or to the inputs when not sampling, for use as the next input.
name: Name scope for any created operations.
Raises:
ValueError: if sampling_probability is not a scalar or vector.
initialize
initialize(name=None)
Returns (initial_finished, initial_inputs).
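As a rough illustration of what initialize computes, the sketch below assumes the behavior it inherits from TrainingHelper: a batch entry starts out finished when its sequence_length is 0, and the initial inputs are the time-step-0 slice of inputs, or zeros of the same shape if every entry is already finished. Tensors are replaced by nested Python lists, time_major selects which axis is time, and auxiliary_inputs are left out. The function name initialize_sketch is hypothetical.

```python
from typing import List, Tuple

Batch = List[List[float]]


def initialize_sketch(
    inputs: List[List[List[float]]],
    sequence_length: List[int],
    time_major: bool = False,
) -> Tuple[List[bool], Batch]:
    """Hypothetical re-creation of (initial_finished, initial_inputs)."""
    # An entry whose length is 0 has nothing to decode.
    initial_finished = [length == 0 for length in sequence_length]
    # Slice out time step 0: inputs[0] if time major, else inputs[b][0].
    step0 = inputs[0] if time_major else [seq[0] for seq in inputs]
    if all(initial_finished):
        # Every entry is finished: feed zeros shaped like one input step.
        step0 = [[0.0] * len(vec) for vec in step0]
    return initial_finished, [list(vec) for vec in step0]
```

For a batch-major input of shape [batch=2, time=3, depth=1] with sequence_length [3, 0], this yields ([False, True], [[1.0], [4.0]]) when the sequences are [[1.0], [2.0], [3.0]] and [[4.0], [5.0], [6.0]]; passing the same data time major with time_major=True gives the same result.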
next_inputs
next_inputs(time, outputs, state, sample_ids, name=None)
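Computes the next inputs (the next_inputs_fn step inherited from TrainingHelper). Returns (finished, next_inputs, state), where finished marks the batch entries for which time + 1 has reached sequence_length. Entries whose sample_ids are True take their next input from outputs (passed through next_inputs_fn, if one was given); all other entries read it from inputs. When auxiliary_inputs were given, they are appended in either case.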
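The per-entry selection that next_inputs performs can be sketched in pure Python as follows. This is a hypothetical illustration, not the helper's code: inputs and auxiliary_inputs are time-major nested lists ([time][batch][depth]), next_inputs_fn defaults to the identity, the auxiliary vector for step time + 1 is concatenated along the last dimension in both branches, and state is passed through unchanged. It assumes time + 1 is a valid step of inputs and omits the zero-input fallback the helper uses once every entry is finished.

```python
from typing import Callable, List, Optional, Tuple

Vector = List[float]


def next_inputs_sketch(
    time: int,
    outputs: List[Vector],
    state: object,
    sample_ids: List[bool],
    inputs: List[List[Vector]],
    sequence_length: List[int],
    next_inputs_fn: Optional[Callable[[Vector], Vector]] = None,
    auxiliary_inputs: Optional[List[List[Vector]]] = None,
) -> Tuple[List[bool], List[Vector], object]:
    """Hypothetical per-entry next_inputs; inputs are time major."""
    next_time = time + 1
    # An entry is finished once the next step reaches its sequence length.
    finished = [next_time >= length for length in sequence_length]
    next_inputs: List[Vector] = []
    for b, sampled in enumerate(sample_ids):
        if sampled:
            # Feed back the RNN output, transformed by next_inputs_fn if given.
            base = next_inputs_fn(outputs[b]) if next_inputs_fn else outputs[b]
        else:
            # Teacher forcing: the ground-truth input for the next step.
            base = inputs[next_time][b]
        # Auxiliary features are appended in both branches.
        aux = auxiliary_inputs[next_time][b] if auxiliary_inputs is not None else []
        next_inputs.append(list(base) + list(aux))
    return finished, next_inputs, state
```

Both branches must end up with the same vector length, since either one becomes the cell's next input; this is why auxiliary_inputs may differ from inputs only in the final dimension.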
sample
sample(time, outputs, state, name=None)
Returns sample_ids.
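Conceptually, sample draws one Bernoulli(sampling_probability) value per batch entry and does not look at outputs or state; a True entry means the next input will be taken from the outputs. The sketch below mimics that with Python's random module. The function sample_sketch is hypothetical, its batch_size and seed parameters merely stand in for the helper's batch_size property and seed argument, and it does not reproduce TensorFlow's random streams or its output dtype.

```python
import random
from typing import List, Optional


def sample_sketch(
    batch_size: int, sampling_probability: float, seed: Optional[int] = None
) -> List[bool]:
    """Hypothetical Bernoulli draw: True where the output is fed back."""
    # Range check added for this sketch only.
    if not 0.0 <= sampling_probability <= 1.0:
        raise ValueError("sampling_probability must lie in [0, 1]")
    rng = random.Random(seed)
    # rng.random() is uniform on [0, 1), so P(True) == sampling_probability.
    return [rng.random() < sampling_probability for _ in range(batch_size)]
```

A probability of 0.0 never samples (every entry is teacher forced), 1.0 always samples, and a fixed seed makes the draw repeatable.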
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/ScheduledOutputTrainingHelper