diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py
index 275245fded..86f3bbcba0 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper.py
@@ -138,6 +138,13 @@ def __init__(self,
         self.values = super(_BaseAttentionMechanism, self).__call__(
             inputs, setup_memory=True)
 
+    @property
+    def memory_initialized(self):
+        """Returns `True` if this attention mechanism has been initialized with
+        a memory.
+        """
+        return self._memory_initialized
+
     def build(self, input_shape):
         if not self._memory_initialized:
             # This is for setting up the memory, which contains memory and
@@ -1680,7 +1687,6 @@ def __init__(self,
                     use_bias=False,
                     dtype=attention_mechanisms[i].dtype)
                 for i, attention_layer_size in enumerate(attention_layer_sizes))
-            self._attention_layer_size = sum(attention_layer_sizes)
         elif attention_layer is not None:
             self._attention_layers = list(
                 attention_layer if isinstance(attention_layer, (
@@ -1690,22 +1696,13 @@ def __init__(self,
                     "If provided, attention_layer must contain exactly one "
                     "layer per attention_mechanism, saw: %d vs %d" % (len(
                         self._attention_layers), len(attention_mechanisms)))
-            self._attention_layer_size = sum(
-                tf.compat.dimension_value(
-                    layer.compute_output_shape([
-                        None, cell.output_size +
-                        tf.compat.dimension_value(mechanism.values.shape[-1])
-                    ])[-1]) for layer, mechanism in zip(
-                        self._attention_layers, attention_mechanisms))
         else:
             self._attention_layers = None
-            self._attention_layer_size = sum(
-                tf.compat.dimension_value(attention_mechanism.values.shape[-1])
-                for attention_mechanism in attention_mechanisms)
 
         if attention_fn is None:
             attention_fn = _compute_attention
         self._attention_fn = attention_fn
+        self._attention_layer_size = None
 
         self._cell = cell
         self._attention_mechanisms = attention_mechanisms
@@ -1735,7 +1732,17 @@ def __init__(self,
                         s, name="check_initial_cell_state"),
                     initial_cell_state)
 
+    def _attention_mechanisms_checks(self):
+        for attention_mechanism in self._attention_mechanisms:
+            if not attention_mechanism.memory_initialized:
+                raise ValueError("The AttentionMechanism instances passed to "
+                                 "this AttentionWrapper should be initialized "
+                                 "with a memory first, either by passing it "
+                                 "to the AttentionMechanism constructor or "
+                                 "calling attention_mechanism.setup_memory()")
+
     def _batch_size_checks(self, batch_size, error_message):
+        self._attention_mechanisms_checks()
         return [
             tf.compat.v1.assert_equal(
                 batch_size,
@@ -1744,6 +1751,26 @@ def _batch_size_checks(self, batch_size, error_message):
             for attention_mechanism in self._attention_mechanisms
         ]
 
+    def _get_attention_layer_size(self):
+        if self._attention_layer_size is not None:
+            return self._attention_layer_size
+        self._attention_mechanisms_checks()
+        attention_output_sizes = (
+            attention_mechanism.values.shape[-1]
+            for attention_mechanism in self._attention_mechanisms)
+        if self._attention_layers is None:
+            self._attention_layer_size = sum(attention_output_sizes)
+        else:
+            # Compute the layer output size from its input which is the
+            # concatenation of the cell output and the attention mechanism
+            # output.
+            self._attention_layer_size = sum(
+                layer.compute_output_shape(
+                    [None, self._cell.output_size + attention_output_size])[-1]
+                for layer, attention_output_size in zip(
+                    self._attention_layers, attention_output_sizes))
+        return self._attention_layer_size
+
     def _item_or_tuple(self, seq):
         """Returns `seq` as tuple or the singular element.
@@ -1767,7 +1794,7 @@ def _item_or_tuple(self, seq):
     @property
     def output_size(self):
         if self._output_attention:
-            return self._attention_layer_size
+            return self._get_attention_layer_size()
         else:
             return self._cell.output_size
 
@@ -1782,7 +1809,7 @@ def state_size(self):
         return AttentionWrapperState(
             cell_state=self._cell.state_size,
             time=tf.TensorShape([]),
-            attention=self._attention_layer_size,
+            attention=self._get_attention_layer_size(),
             alignments=self._item_or_tuple(
                 a.alignments_size for a in self._attention_mechanisms),
             attention_state=self._item_or_tuple(
@@ -1841,7 +1868,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
         return AttentionWrapperState(
             cell_state=cell_state,
             time=tf.zeros([], dtype=tf.int32),
-            attention=_zero_state_tensors(self._attention_layer_size,
+            attention=_zero_state_tensors(self._get_attention_layer_size(),
                                           batch_size, dtype),
             alignments=self._item_or_tuple(initial_alignments),
             attention_state=self._item_or_tuple(
diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py
index 422e56c9d7..099590d09a 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper_test.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -252,6 +252,28 @@ def setUp(self):
         self.decoder_sequence_length = np.random.randint(
             self.decoder_timestep, size=(self.batch,)).astype(np.int32)
 
+    def testCustomAttentionLayer(self):
+        attention_mechanism = wrapper.LuongAttention(self.units)
+        cell = tf.keras.layers.LSTMCell(self.units)
+        attention_layer = tf.keras.layers.Dense(
+            self.units * 2, use_bias=False, activation=tf.math.tanh)
+        attention_wrapper = wrapper.AttentionWrapper(
+            cell, attention_mechanism, attention_layer=attention_layer)
+        with self.assertRaises(ValueError):
+            # Should fail because the attention mechanism has not been
+            # initialized.
+            attention_wrapper.get_initial_state(
+                batch_size=self.batch, dtype=tf.float32)
+        attention_mechanism.setup_memory(
+            self.encoder_outputs.astype(np.float32),
+            memory_sequence_length=self.encoder_sequence_length)
+        initial_state = attention_wrapper.get_initial_state(
+            batch_size=self.batch, dtype=tf.float32)
+        self.assertEqual(initial_state.attention.shape[-1], self.units * 2)
+        first_input = self.decoder_inputs[:, 0].astype(np.float32)
+        output, next_state = attention_wrapper(first_input, initial_state)
+        self.assertEqual(output.shape[-1], self.units * 2)
+
     def _testWithAttention(self,
                            create_attention_mechanism,
                            expected_final_output,
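
Usage note: the change defers computing the attention layer size until the attention mechanism has a memory, so an AttentionWrapper can now be constructed before the encoder has run. Below is a minimal sketch of that flow, mirroring testCustomAttentionLayer above; the sizes, shapes, and random inputs are illustrative only, and the example assumes tensorflow and tensorflow_addons are importable as tf and tfa.

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

units, batch, src_time, src_depth = 32, 4, 10, 16

# Build the wrapper before any memory is available.
attention_mechanism = tfa.seq2seq.LuongAttention(units)
cell = tf.keras.layers.LSTMCell(units)
attention_wrapper = tfa.seq2seq.AttentionWrapper(cell, attention_mechanism)

# Attach the memory later, e.g. once the encoder outputs exist.
encoder_outputs = np.random.randn(batch, src_time, src_depth).astype(np.float32)
attention_mechanism.setup_memory(encoder_outputs)

# The attention size is resolved lazily here; calling this before
# setup_memory() would raise the new ValueError.
initial_state = attention_wrapper.get_initial_state(
    batch_size=batch, dtype=tf.float32)
step_input = tf.random.normal([batch, units])
output, next_state = attention_wrapper(step_input, initial_state)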