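Module for constructing seq2seq models and dynamic decoding. Builds on top of libraries in tf.contrib.rnn.
This library is composed of two primary components: attention wrappers for RNNCell objects, and an object-oriented dynamic decoding framework.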
Attention wrappers are RNNCell objects that wrap other RNNCell objects and implement attention. The form of attention is determined by a subclass of tf.contrib.seq2seq.AttentionMechanism. These subclasses describe the form of attention (e.g. additive vs. multiplicative) to use when creating the wrapper. An instance of an AttentionMechanism is constructed with a memory tensor, from which lookup keys and values tensors are created.
The memory tensor passed to the attention mechanism's constructor is expected to be shaped [batch_size, memory_max_time, memory_depth]; often an additional memory_sequence_length vector is also accepted. If it is provided, the rows of the memory tensor are masked with zeros past their true sequence lengths.
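To make the masking concrete, here is a minimal pure-Python sketch (no TensorFlow) of what zeroing memory rows past memory_sequence_length means for a [batch_size, memory_max_time, memory_depth] tensor. The helper name mask_memory is hypothetical; inside the library this is done with tensor ops (roughly, multiplying by a mask built with tf.sequence_mask), not Python loops.

```python
from typing import List

Matrix = List[List[float]]


def mask_memory(memory: List[Matrix], sequence_lengths: List[int]) -> List[Matrix]:
    """Hypothetical helper: zero out each batch entry's time steps past its length.

    memory has shape [batch_size][memory_max_time][memory_depth].
    """
    masked = []
    for rows, length in zip(memory, sequence_lengths):
        masked.append([
            list(row) if t < length else [0.0] * len(row)  # keep valid steps, zero the rest
            for t, row in enumerate(rows)
        ])
    return masked


# batch_size=2, memory_max_time=3, memory_depth=2
memory = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[7.0, 8.0], [9.0, 1.0], [2.0, 3.0]],
]
masked = mask_memory(memory, sequence_lengths=[3, 1])
print(masked[1])  # [[7.0, 8.0], [0.0, 0.0], [0.0, 0.0]]
```

The first entry has full length 3 and is unchanged; the second has true length 1, so only its first time step survives.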
Attention mechanisms also have a concept of depth, usually determined by a construction parameter num_units. For some kinds of attention (like BahdanauAttention), both queries and memory are projected to tensors of depth num_units. For other kinds (like LuongAttention), num_units should match the depth of the queries, and the memory tensor will be projected to this depth.
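The difference in how num_units is used can be seen in a pure-Python sketch of one common formulation of each score: additive (Bahdanau-style, v · tanh(W_q q + W_k k_j)) and multiplicative (Luong-style, q · W_k k_j). The function names and weight matrices are hypothetical, bias terms are omitted, and the optional scaling/normalization variants the library offers are not shown.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]  # rows = output dims, columns = input dims


def matvec(w: Matrix, x: Vector) -> Vector:
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def luong_scores(query: Vector, memory: Matrix, w_memory: Matrix) -> Vector:
    # Only memory is projected to num_units; the query must already have that depth.
    keys = [matvec(w_memory, m) for m in memory]
    return [dot(query, k) for k in keys]


def bahdanau_scores(query: Vector, memory: Matrix, w_query: Matrix,
                    w_memory: Matrix, v: Vector) -> Vector:
    # Both query and memory are projected to num_units, then combined additively.
    q = matvec(w_query, query)
    keys = [matvec(w_memory, m) for m in memory]
    return [dot(v, [math.tanh(qi + ki) for qi, ki in zip(q, k)]) for k in keys]


memory = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]   # memory_max_time=2, memory_depth=3
w_memory = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]  # projects depth 3 -> num_units=2

luong = luong_scores([2.0, 1.0], memory, w_memory)  # query depth == num_units
print(luong)  # [2.0, 1.0]

w_query = [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]]   # projects a depth-3 query -> num_units=2
bahdanau = bahdanau_scores([1.0, 0.0, 0.0], memory, w_query, w_memory, v=[1.0, 1.0])
```

Note that the Luong-style query has depth 2 (equal to num_units) while the Bahdanau-style query has depth 3, because only the additive form projects the query.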
The basic attention wrapper is tf.contrib.seq2seq.AttentionWrapper. This wrapper accepts an RNNCell instance, an instance of AttentionMechanism, and an attention depth parameter (attention_size), as well as several optional arguments that allow one to customize intermediate calculations.
At each time step, the basic calculation performed by this wrapper is:
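```python
# Pseudocode for one step of the wrapper on batched tensors (not the exact implementation).
def call(inputs, prev_state):
    cell_inputs = concat([inputs, prev_state.attention], -1)
    cell_output, next_cell_state = cell(cell_inputs, prev_state.cell_state)
    score = attention_mechanism(cell_output)   # [batch_size, memory_max_time]
    alignments = softmax(score)                # [batch_size, memory_max_time]
    # Weighted sum of values: [batch, 1, max_time] x [batch, max_time, depth] -> [batch, depth]
    context = squeeze(matmul(expand_dims(alignments, 1), attention_mechanism.values), [1])
    attention = tf.layers.Dense(attention_size)(concat([cell_output, context], 1))
    next_state = AttentionWrapperState(
        cell_state=next_cell_state,
        attention=attention)
    output = attention
    return output, next_state
```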
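The alignments, context, and attention lines of the step above can be traced by hand with a pure-Python sketch for a single batch entry. The name attention_step is hypothetical, the scores are taken as given rather than computed by an attention mechanism, and the Dense layer is modeled as a bias-free weight matrix w_attention (an assumption made for illustration).

```python
import math
from typing import List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_step(cell_output: Vector, scores: Vector, values: Matrix,
                   w_attention: Matrix) -> Tuple[Vector, Vector]:
    """Hypothetical single-entry version of alignments -> context -> attention.

    scores: one score per memory time step.
    values: [memory_max_time][value_depth].
    w_attention: [attention_size][len(cell_output) + value_depth], no bias.
    """
    alignments = softmax(scores)
    # context: the value rows averaged with the alignments as weights
    context = [sum(a * row[d] for a, row in zip(alignments, values))
               for d in range(len(values[0]))]
    mixed = cell_output + context  # list concatenation plays the role of concat(..., 1)
    attention = [sum(w * x for w, x in zip(row, mixed)) for row in w_attention]
    return alignments, attention


alignments, attention = attention_step(
    cell_output=[1.0],
    scores=[0.0, 0.0],                      # equal scores -> uniform alignments
    values=[[2.0, 0.0], [0.0, 4.0]],
    w_attention=[[1.0, 0.0, 0.0], [0.0, 1.0, 1.0]],  # attention_size=2
)
print(alignments)  # [0.5, 0.5]
print(attention)   # [1.0, 3.0]
```

With equal scores the context is the plain mean of the value rows, [1.0, 2.0], which the projection then mixes with cell_output.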
In practice, a number of the intermediate calculations are configurable. For example, the initial concatenation of inputs and prev_state.attention can be replaced with another mixing function. The softmax function can be replaced with alternative options when calculating alignments from the score. Finally, the output returned by the wrapper can be configured to be cell_output instead of attention.
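The three configuration points can be illustrated with a pure-Python sketch that makes the mixing function, the probability function, and the output choice pluggable. The parameter names mix_fn, probability_fn, and output_attention here are hypothetical stand-ins for illustration, not a statement of the library's argument names; hardmax is shown as one possible alternative to softmax, and the Dense projection is replaced by plain concatenation to keep the arithmetic visible.

```python
import math
from typing import Callable, List, Tuple

Vector = List[float]
Matrix = List[List[float]]


def softmax(scores: Vector) -> Vector:
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def hardmax(scores: Vector) -> Vector:
    # One-hot vector at the first maximal score.
    best = scores.index(max(scores))
    return [1.0 if i == best else 0.0 for i in range(len(scores))]


def concat_mix(inputs: Vector, prev_attention: Vector) -> Vector:
    # Default mixing: concatenate along the last axis.
    return inputs + prev_attention


def wrapper_step(inputs: Vector, prev_attention: Vector,
                 cell_fn: Callable[[Vector], Vector],
                 score_fn: Callable[[Vector], Vector],
                 values: Matrix,
                 mix_fn: Callable[[Vector, Vector], Vector] = concat_mix,
                 probability_fn: Callable[[Vector], Vector] = softmax,
                 output_attention: bool = True) -> Tuple[Vector, Vector]:
    cell_output = cell_fn(mix_fn(inputs, prev_attention))
    alignments = probability_fn(score_fn(cell_output))
    context = [sum(a * row[d] for a, row in zip(alignments, values))
               for d in range(len(values[0]))]
    attention = cell_output + context  # simplification: no Dense projection
    output = attention if output_attention else cell_output
    return output, attention


def toy_cell(x: Vector) -> Vector:
    return [sum(x)]  # hypothetical one-unit "cell"


def toy_score(cell_output: Vector) -> Vector:
    return [cell_output[0], 0.0]  # hypothetical scores over two memory steps


values = [[1.0, 0.0], [0.0, 1.0]]
hard_out, _ = wrapper_step([1.0], [0.0, 0.0, 0.0], toy_cell, toy_score, values,
                           probability_fn=hardmax)
print(hard_out)  # [1.0, 1.0, 0.0]
soft_out, soft_attn = wrapper_step([1.0], [0.0, 0.0, 0.0], toy_cell, toy_score,
                                   values, output_attention=False)
print(soft_out)  # [1.0]
```

With hardmax the context copies a single value row exactly; with softmax it blends both rows, and setting output_attention=False returns the raw cell output while the attention vector is still computed for the next step.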
The benefit of using an AttentionWrapper is that it composes cleanly with other wrappers and with the dynamic decoder described below. For example, one can write:
```python
import tensorflow as tf

# encoder_outputs: [batch_size, memory_max_time, memory_depth], assumed to come from an encoder.
cell = tf.contrib.rnn.DeviceWrapper(tf.contrib.rnn.LSTMCell(512), "/device:GPU:0")
attention_mechanism = tf.contrib.seq2seq.LuongAttention(512, encoder_outputs)
attn_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell, attention_mechanism, attention_size=256)
attn_cell = tf.contrib.rnn.DeviceWrapper(attn_cell, "/device:GPU:1")
top_cell = tf.contrib.rnn.DeviceWrapper(tf.contrib.rnn.LSTMCell(512), "/device:GPU:1")
multi_cell = tf.contrib.rnn.MultiRNNCell([attn_cell, top_cell])
```
The multi_cell will perform the bottom layer calculations on GPU 0; attention calculations will be performed on GPU 1 and immediately passed up to the top layer, which is also calculated on GPU 1. The attention is also passed forward in time and copied to GPU 0 for the next time step of cell. (Note: this is just an example of use, not a suggested device partitioning strategy.)
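The dynamic decoding component combines a helper, which decides how inputs are fed to the decoder at each step (for example TrainingHelper during training or GreedyEmbeddingHelper at inference), a decoder such as BasicDecoder, and the dynamic_decode function, which runs the decoder until every sequence has finished or maximum_iterations is reached. Example usage:

```python
import tensorflow as tf

# Assumed to be defined elsewhere: mode, batch_size, input_vectors, input_lengths,
# embedding, GO_SYMBOL, END_SYMBOL.
cell = tf.contrib.rnn.LSTMCell(512)  # any RNNCell instance works here
if mode == "train":
    helper = tf.contrib.seq2seq.TrainingHelper(
        inputs=input_vectors,
        sequence_length=input_lengths)
elif mode == "infer":
    helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(
        embedding=embedding,
        start_tokens=tf.tile([GO_SYMBOL], [batch_size]),
        end_token=END_SYMBOL)
else:
    raise ValueError("Unknown mode: %s" % mode)
decoder = tf.contrib.seq2seq.BasicDecoder(
    cell=cell,
    helper=helper,
    initial_state=cell.zero_state(batch_size, tf.float32))
# The first element of the returned tuple holds the decoder outputs.
outputs = tf.contrib.seq2seq.dynamic_decode(
    decoder=decoder,
    output_time_major=False,
    impute_finished=True,
    maximum_iterations=20)[0]
```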
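In the "infer" branch of the example above, the decoder feeds its own greedy predictions back as inputs. The pure-Python sketch below approximates that loop conceptually for a small batch: each entry starts from its start token, takes the argmax of the step's logits as the next input, stops at end_token, and (mimicking impute_finished=True) emits zeros and carries its state through unchanged once finished. The step function toy_step and the integer state are hypothetical; it is assumed to do any embedding lookup internally, and the real dynamic_decode runs a tf.while_loop over batched tensors rather than Python loops.

```python
from typing import Callable, List, Sequence, Tuple

StepFn = Callable[[int, int], Tuple[List[float], int]]


def argmax(values: Sequence[float]) -> int:
    return max(range(len(values)), key=lambda i: values[i])


def greedy_decode(step_fn: StepFn, initial_states: List[int],
                  start_tokens: List[int], end_token: int,
                  maximum_iterations: int) -> Tuple[List[List[int]], List[int]]:
    """Greedy decoding with impute_finished-style handling of finished entries."""
    batch_size = len(start_tokens)
    tokens = list(start_tokens)
    states = list(initial_states)
    finished = [False] * batch_size
    outputs: List[List[int]] = [[] for _ in range(batch_size)]
    for _ in range(maximum_iterations):
        if all(finished):
            break
        for b in range(batch_size):
            if finished[b]:
                # Finished entries emit a zero id and keep their state unchanged.
                outputs[b].append(0)
                continue
            logits, states[b] = step_fn(tokens[b], states[b])
            tokens[b] = argmax(logits)  # the greedy prediction becomes the next input
            outputs[b].append(tokens[b])
            finished[b] = tokens[b] == end_token
    return outputs, states


def toy_step(token: int, state: int) -> Tuple[List[float], int]:
    # Hypothetical step: always predicts (token + 1) mod 4; the state counts steps.
    logits = [0.0] * 4
    logits[(token + 1) % 4] = 1.0
    return logits, state + 1


outputs, states = greedy_decode(toy_step, [0, 0], start_tokens=[1, 2],
                                end_token=3, maximum_iterations=20)
print(outputs)  # [[2, 3], [3, 0]]
print(states)   # [2, 1]
```

The second entry emits the end token on the first step, so its remaining output is padded with 0 and its state stops advancing; decoding ends as soon as both entries have finished, well before maximum_iterations.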
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.