Class AttentionWrapper

Wraps another RNNCell with attention.

Inherits From: RNNCell
__init__
__init__(
    cell,
    attention_mechanism,
    attention_layer_size=None,
    alignment_history=False,
    cell_input_fn=None,
    output_attention=True,
    initial_cell_state=None,
    name=None,
    attention_layer=None,
    attention_fn=None,
    dtype=None
)
Construct the AttentionWrapper.
NOTE: If you are using the BeamSearchDecoder with a cell wrapped in AttentionWrapper, then you must ensure that:

- The encoder output has been tiled to beam_width via tf.contrib.seq2seq.tile_batch (NOT tf.tile).
- The batch_size argument passed to the zero_state method of this wrapper is equal to true_batch_size * beam_width.
- The initial state created with zero_state above contains a cell_state value containing the properly tiled final state from the encoder.
An example:
tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
    encoder_outputs, multiplier=beam_width)
tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
    encoder_final_state, multiplier=beam_width)
tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
    sequence_length, multiplier=beam_width)
attention_mechanism = MyFavoriteAttentionMechanism(
    num_units=attention_depth,
    memory=tiled_encoder_outputs,
    memory_sequence_length=tiled_sequence_length)
attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
decoder_initial_state = attention_cell.zero_state(
    batch_size=true_batch_size * beam_width, dtype=dtype)
decoder_initial_state = decoder_initial_state.clone(
    cell_state=tiled_encoder_final_state)
Args:

- cell: An instance of RNNCell.
- attention_mechanism: A list of AttentionMechanism instances or a single instance.
- attention_layer_size: A list of Python integers or a single Python integer, the depth of the attention (output) layer(s). If None (default), use the context as attention at each time step. Otherwise, feed the context and cell output into the attention layer to generate attention at each time step. If attention_mechanism is a list, attention_layer_size must be a list of the same length. If attention_layer is set, this must be None. If attention_fn is set, it must be guaranteed that the outputs of attention_fn also meet the above requirements.
- alignment_history: Python boolean, whether to store alignment history from all time steps in the final output state (currently stored as a time major TensorArray on which you must call stack()).
- cell_input_fn: (optional) A callable. The default is: lambda inputs, attention: array_ops.concat([inputs, attention], -1).
- output_attention: Python bool. If True (default), the output at each time step is the attention value. This is the behavior of Luong-style attention mechanisms. If False, the output at each time step is the output of cell. This is the behavior of Bahdanau-style attention mechanisms. In both cases, the attention tensor is propagated to the next time step via the state and is used there. This flag only controls whether the attention mechanism is propagated up to the next cell in an RNN stack or to the top RNN output.
- initial_cell_state: The initial state value to use for the cell when the user calls zero_state(). Note that if this value is provided now, and the user uses a batch_size argument of zero_state which does not match the batch size of initial_cell_state, proper behavior is not guaranteed.
- name: Name to use when creating ops.
- attention_layer: A list of tf.compat.v1.layers.Layer instances or a single tf.compat.v1.layers.Layer instance taking the context and cell output as inputs to generate attention at each time step. If None (default), use the context as attention at each time step. If attention_mechanism is a list, attention_layer must be a list of the same length. If attention_layer_size is set, this must be None.
- attention_fn: An optional callable function that allows users to provide their own customized attention function, which takes input (attention_mechanism, cell_output, attention_state, attention_layer) and outputs (attention, alignments, next_attention_state). If provided, attention_layer_size should be the size of the outputs of attention_fn.
- dtype: The cell dtype.
Raises:

- TypeError: attention_layer_size is not None and (attention_mechanism is a list but attention_layer_size is not; or vice versa).
- ValueError: if attention_layer_size is not None, attention_mechanism is a list, and its length does not match that of attention_layer_size; if attention_layer_size and attention_layer are set simultaneously.
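For illustration, a minimal construction sketch (not part of the API reference; num_units, encoder_outputs, and source_sequence_length are placeholder values chosen here):

import tensorflow as tf

# Placeholder shapes for illustration only.
batch_size, max_time, num_units = 32, 40, 128
encoder_outputs = tf.zeros([batch_size, max_time, num_units])
source_sequence_length = tf.fill([batch_size], max_time)

cell = tf.contrib.rnn.LSTMCell(num_units)
attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
    num_units=num_units,
    memory=encoder_outputs,
    memory_sequence_length=source_sequence_length)
attention_cell = tf.contrib.seq2seq.AttentionWrapper(
    cell,
    attention_mechanism,
    attention_layer_size=num_units,  # feed context + cell output through an attention layer
    alignment_history=True,          # keep per-step alignments in the state
    output_attention=False)          # Bahdanau-style: emit the cell output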
Properties
graph
DEPRECATED FUNCTION
output_size
Integer or TensorShape: size of outputs produced by this cell.
scope_name
state_size
The state_size property of AttentionWrapper.
Returns:
An AttentionWrapperState tuple containing shapes used by this object.
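For example, when the wrapper is constructed with alignment_history=True, the alignment_history field of the final AttentionWrapperState is a TensorArray that can be stacked into a dense tensor (a sketch; final_state here is assumed to be the state returned by tf.contrib.seq2seq.dynamic_decode):

# Assuming `final_state` is the AttentionWrapperState returned as the second
# element of tf.contrib.seq2seq.dynamic_decode(...).
alignments = final_state.alignment_history.stack()
# shape: [decoder_time, batch_size, encoder_time]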
Methods
tf.contrib.seq2seq.AttentionWrapper.get_initial_state
get_initial_state(
    inputs=None,
    batch_size=None,
    dtype=None
)
tf.contrib.seq2seq.AttentionWrapper.zero_state
zero_state(
    batch_size,
    dtype
)
Return an initial (zero) state tuple for this AttentionWrapper.

NOTE: Please see the initializer documentation for details of how to call zero_state if using an AttentionWrapper with a BeamSearchDecoder.
Args:
- batch_size: 0D integer tensor: the batch size.
- dtype: The internal state data type.
Returns:
An AttentionWrapperState tuple containing zeroed out tensors and, possibly, empty TensorArray objects.
Raises:
- ValueError: (or, possibly at runtime, InvalidArgument), if batch_size does not match the output size of the encoder passed to the wrapper object at initialization time.
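A minimal usage sketch without beam search (attention_cell, encoder_final_state, and batch_size are placeholders from a hypothetical encoder/decoder setup):

decoder_initial_state = attention_cell.zero_state(
    batch_size=batch_size, dtype=tf.float32)
# Optionally seed the wrapped cell's state with the encoder's final state.
decoder_initial_state = decoder_initial_state.clone(
    cell_state=encoder_final_state)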