tf.contrib.seq2seq.SampleEmbeddingHelper

Class SampleEmbeddingHelper

Inherits From: GreedyEmbeddingHelper

Defined in tensorflow/contrib/seq2seq/python/ops/helper.py.

A helper for use during inference.

Uses sampling (from the categorical distribution defined by the output logits) instead of argmax, and passes the result through an embedding layer to get the next input.

Properties

batch_size

Batch size of the tensor returned by sample.

Returns a scalar int32 tensor.

sample_ids_dtype

DType of the tensor returned by sample (int32 for this helper).

Returns a DType.

sample_ids_shape

Shape of the tensor returned by sample, excluding the batch dimension (a scalar shape, since each batch entry gets a single id).

Returns a TensorShape.

Methods

__init__

__init__(
    embedding,
    start_tokens,
    end_token,
    softmax_temperature=None,
    seed=None
)

Initializer.

Args:

  • embedding: A callable that takes a vector tensor of ids (the sampled ids), or the params argument for embedding_lookup. The returned tensor is passed to the decoder as its next input.
  • start_tokens: int32 vector shaped [batch_size], the start tokens.
  • end_token: int32 scalar, the token that marks end of decoding.
  • softmax_temperature: (Optional) float32 scalar, value to divide the logits by before computing the softmax. Larger values (above 1.0) result in more random samples, while smaller values push the sampling distribution towards the argmax. Must be strictly greater than 0. Defaults to 1.0.
  • seed: (Optional) The sampling seed.
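To make the softmax_temperature description concrete, here is a minimal pure-Python sketch (not TensorFlow code) of dividing logits by a temperature before the softmax. The function name softmax_with_temperature and the example logits are illustrative assumptions.

```python
import math


def softmax_with_temperature(logits, temperature=1.0):
    """Return softmax(logits / temperature) as a list of probabilities.

    Hypothetical helper for illustration; it mirrors the documented rule
    that logits are divided by the temperature before the softmax.
    """
    if temperature <= 0:
        raise ValueError("temperature must be strictly greater than 0")
    scaled = [x / temperature for x in logits]
    # Subtract the max before exponentiating for numerical stability.
    top = max(scaled)
    exps = [math.exp(x - top) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    # Lower t concentrates mass on the largest logit; higher t flattens it.
    print(t, [round(p, 3) for p in probs])
```

Dividing by a positive temperature never changes the ordering of the logits, so the argmax stays the same; only how sharply the distribution peaks around it changes.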

Raises:

  • ValueError: if start_tokens is not a 1D tensor or end_token is not a scalar.
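The embedding argument accepts either a callable or an embedding matrix. The sketch below, in plain Python with hypothetical names make_embedding_fn and check_tokens, shows one way to normalize the two forms into a single lookup function and to perform the documented shape checks. Indexing a list of rows stands in for tf.nn.embedding_lookup, and Python lists and ints stand in for tensors.

```python
def make_embedding_fn(embedding):
    """Return a callable mapping a list of ids to a list of embedding rows."""
    if callable(embedding):
        # Already a lookup function: use it unchanged.
        return embedding
    # Otherwise treat `embedding` as a matrix (list of rows) indexed by id.
    return lambda ids: [embedding[i] for i in ids]


def check_tokens(start_tokens, end_token):
    """Raise ValueError unless start_tokens is 1-D and end_token is a scalar."""
    is_vector = isinstance(start_tokens, list) and all(
        isinstance(t, int) for t in start_tokens
    )
    if not is_vector:
        raise ValueError("start_tokens must be a vector")
    if not isinstance(end_token, int):
        raise ValueError("end_token must be a scalar")


table = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
lookup = make_embedding_fn(table)
check_tokens([1, 1], 2)
print(lookup([2, 0]))
```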

initialize

initialize(name=None)

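Returns (initial_finished, initial_inputs): initial_finished is a boolean vector of length batch_size that is all False, and initial_inputs is the embedding of start_tokens.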
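As a rough illustration, the initialize step can be sketched in plain Python, assuming embedding_fn is a lookup callable like the one the initializer builds. The function and variable names here are hypothetical, not TensorFlow's implementation.

```python
def initialize(start_tokens, embedding_fn):
    """Return (initial_finished, initial_inputs) for a batch of start tokens."""
    batch_size = len(start_tokens)
    # No sequence has finished before decoding starts.
    initial_finished = [False] * batch_size
    # The first decoder input is the embedding of each start token.
    initial_inputs = embedding_fn(start_tokens)
    return initial_finished, initial_inputs


table = [[0.0], [1.0], [2.0]]
finished, inputs = initialize([1, 1], lambda ids: [table[i] for i in ids])
print(finished, inputs)
```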

next_inputs

next_inputs(
    time,
    outputs,
    state,
    sample_ids,
    name=None
)

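next_inputs_fn inherited from GreedyEmbeddingHelper. Marks each batch entry as finished where sample_ids equals end_token and returns (finished, next_inputs, state), where next_inputs is the embedding of sample_ids.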
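A minimal pure-Python sketch of this step, assuming sample_ids is a list of ints and embedding_fn is a lookup callable. The parameter start_inputs (the embedded start tokens) and all names are illustrative assumptions; the time and outputs arguments are omitted because the step does not use them.

```python
def next_inputs(sample_ids, end_token, start_inputs, embedding_fn, state):
    """Return (finished, next_inputs, state) for one decoding step."""
    # A batch entry is finished once it emits the end token.
    finished = [i == end_token for i in sample_ids]
    if all(finished):
        # Every entry is done, so the next input value does not matter;
        # reuse the start inputs instead of doing another lookup.
        next_in = start_inputs
    else:
        # Feed the embedding of each sampled id back into the decoder.
        next_in = embedding_fn(sample_ids)
    # The decoder state is passed through unchanged.
    return finished, next_in, state


table = [[0.0], [1.0], [2.0]]
emb = lambda ids: [table[i] for i in ids]
start = [[1.0], [1.0]]
print(next_inputs([2, 0], 2, start, emb, "state"))
```

Note that entries which are already finished still receive an embedding while others are running; downstream decoding logic is responsible for ignoring them.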

sample

sample(
    time,
    outputs,
    state,
    name=None
)

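sample for SampleEmbeddingHelper. Divides outputs (the logits) by softmax_temperature, if one was given, and draws sample_ids from the resulting categorical distribution using seed.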
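A pure-Python sketch of the sampling step, assuming outputs is a batch of logit lists. TensorFlow draws from a categorical distribution over the logits; here random.Random(seed).choices with softmax weights is swapped in as a stand-in, so the actual ids produced for a given seed will not match TensorFlow's.

```python
import math
import random


def sample(logits_batch, softmax_temperature=None, seed=None):
    """Draw one id per batch entry from softmax(logits / temperature)."""
    rng = random.Random(seed)
    sample_ids = []
    for logits in logits_batch:
        if softmax_temperature is not None:
            logits = [x / softmax_temperature for x in logits]
        # Unnormalized softmax weights; random.choices normalizes them.
        top = max(logits)
        weights = [math.exp(x - top) for x in logits]
        idx = rng.choices(range(len(logits)), weights=weights, k=1)[0]
        sample_ids.append(idx)
    return sample_ids


batch = [[2.0, 1.0, 0.1], [0.0, 0.0, 3.0]]
print(sample(batch, softmax_temperature=0.7, seed=1))
```

Fixing the seed makes repeated calls reproducible, and a very small temperature makes the draw behave almost like the argmax used by GreedyEmbeddingHelper.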

© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/seq2seq/SampleEmbeddingHelper