55 changes: 41 additions & 14 deletions tensorflow_addons/seq2seq/attention_wrapper.py
@@ -138,6 +138,13 @@ def __init__(self,
self.values = super(_BaseAttentionMechanism, self).__call__(
inputs, setup_memory=True)

@property
def memory_initialized(self):
"""Returns `True` if this attention mechanism has been initialized with
a memory.
"""
return self._memory_initialized

def build(self, input_shape):
if not self._memory_initialized:
# This is for setting up the memory, which contains memory and
@@ -1680,7 +1687,6 @@ def __init__(self,
use_bias=False,
dtype=attention_mechanisms[i].dtype) for i,
attention_layer_size in enumerate(attention_layer_sizes))
self._attention_layer_size = sum(attention_layer_sizes)
elif attention_layer is not None:
self._attention_layers = list(
attention_layer if isinstance(attention_layer, (
@@ -1690,22 +1696,13 @@
"If provided, attention_layer must contain exactly one "
"layer per attention_mechanism, saw: %d vs %d" % (len(
self._attention_layers), len(attention_mechanisms)))
self._attention_layer_size = sum(
tf.compat.dimension_value(
layer.compute_output_shape([
None, cell.output_size +
tf.compat.dimension_value(mechanism.values.shape[-1])
])[-1]) for layer, mechanism in zip(
self._attention_layers, attention_mechanisms))
else:
self._attention_layers = None
self._attention_layer_size = sum(
tf.compat.dimension_value(attention_mechanism.values.shape[-1])
for attention_mechanism in attention_mechanisms)

if attention_fn is None:
attention_fn = _compute_attention
self._attention_fn = attention_fn
self._attention_layer_size = None

self._cell = cell
self._attention_mechanisms = attention_mechanisms
@@ -1735,7 +1732,17 @@ def __init__(self,
s, name="check_initial_cell_state"),
initial_cell_state)

def _attention_mechanisms_checks(self):
for attention_mechanism in self._attention_mechanisms:
if not attention_mechanism.memory_initialized:
raise ValueError("The AttentionMechanism instances passed to "
"this AttentionWrapper should be initialized "
"with a memory first, either by passing it "
"to the AttentionMechanism constructor or "
"calling attention_mechanism.setup_memory()")

def _batch_size_checks(self, batch_size, error_message):
self._attention_mechanisms_checks()
return [
tf.compat.v1.assert_equal(
batch_size,
@@ -1744,6 +1751,26 @@ def _batch_size_checks(self, batch_size, error_message):
for attention_mechanism in self._attention_mechanisms
]

def _get_attention_layer_size(self):
if self._attention_layer_size is not None:
return self._attention_layer_size
self._attention_mechanisms_checks()
attention_output_sizes = (
attention_mechanism.values.shape[-1]
for attention_mechanism in self._attention_mechanisms)
if self._attention_layers is None:
self._attention_layer_size = sum(attention_output_sizes)
else:
# Compute the layer output size from its input which is the
# concatenation of the cell output and the attention mechanism
# output.
self._attention_layer_size = sum(
layer.compute_output_shape(
[None, self._cell.output_size + attention_output_size])[-1]
for layer, attention_output_size in zip(
self._attention_layers, attention_output_sizes))
return self._attention_layer_size

def _item_or_tuple(self, seq):
"""Returns `seq` as tuple or the singular element.

@@ -1767,7 +1794,7 @@ def _item_or_tuple(self, seq):
@property
def output_size(self):
if self._output_attention:
return self._attention_layer_size
return self._get_attention_layer_size()
else:
return self._cell.output_size

@@ -1782,7 +1809,7 @@ def state_size(self):
return AttentionWrapperState(
cell_state=self._cell.state_size,
time=tf.TensorShape([]),
attention=self._attention_layer_size,
attention=self._get_attention_layer_size(),
alignments=self._item_or_tuple(
a.alignments_size for a in self._attention_mechanisms),
attention_state=self._item_or_tuple(
@@ -1841,7 +1868,7 @@ def get_initial_state(self, inputs=None, batch_size=None, dtype=None):
return AttentionWrapperState(
cell_state=cell_state,
time=tf.zeros([], dtype=tf.int32),
attention=_zero_state_tensors(self._attention_layer_size,
attention=_zero_state_tensors(self._get_attention_layer_size(),
batch_size, dtype),
alignments=self._item_or_tuple(initial_alignments),
attention_state=self._item_or_tuple(
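Reviewer note: the net effect of the change above is that `AttentionWrapper` no longer needs to know the attention output size at construction time; it is computed lazily once every attention mechanism has a memory. A minimal sketch of the deferred-memory usage this enables, assuming the public `tensorflow_addons.seq2seq` API mirrors the module under test (shapes and the random memory are illustrative only):

```python
import numpy as np
import tensorflow as tf
from tensorflow_addons import seq2seq

units, batch, max_time = 8, 4, 10

# The mechanism and wrapper can be built before the encoder output exists.
attention_mechanism = seq2seq.LuongAttention(units)  # no memory yet
cell = tf.keras.layers.LSTMCell(units)
attention_wrapper = seq2seq.AttentionWrapper(cell, attention_mechanism)

# The memory is provided later, e.g. once the encoder has run.
memory = np.random.randn(batch, max_time, units).astype(np.float32)
attention_mechanism.setup_memory(memory)

# Only now is the attention layer size needed, and computed lazily.
initial_state = attention_wrapper.get_initial_state(
    batch_size=batch, dtype=tf.float32)
```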
22 changes: 22 additions & 0 deletions tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -252,6 +252,28 @@ def setUp(self):
self.decoder_sequence_length = np.random.randint(
self.decoder_timestep, size=(self.batch,)).astype(np.int32)

def testCustomAttentionLayer(self):
attention_mechanism = wrapper.LuongAttention(self.units)
cell = tf.keras.layers.LSTMCell(self.units)
attention_layer = tf.keras.layers.Dense(
self.units * 2, use_bias=False, activation=tf.math.tanh)
attention_wrapper = wrapper.AttentionWrapper(
cell, attention_mechanism, attention_layer=attention_layer)
with self.assertRaises(ValueError):
# Should fail because the attention mechanism has not been
# initialized.
attention_wrapper.get_initial_state(
batch_size=self.batch, dtype=tf.float32)
attention_mechanism.setup_memory(
self.encoder_outputs.astype(np.float32),
memory_sequence_length=self.encoder_sequence_length)
initial_state = attention_wrapper.get_initial_state(
batch_size=self.batch, dtype=tf.float32)
self.assertEqual(initial_state.attention.shape[-1], self.units * 2)
first_input = self.decoder_inputs[:, 0].astype(np.float32)
output, next_state = attention_wrapper(first_input, initial_state)
self.assertEqual(output.shape[-1], self.units * 2)

def _testWithAttention(self,
create_attention_mechanism,
expected_final_output,
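Reviewer note: the `units * 2` expectation in the test follows from the custom-`attention_layer` branch of `_get_attention_layer_size`, which feeds the layer the concatenation of the cell output and the attention context. A standalone sketch of that size computation (not part of the PR; `units` and `memory_depth` are illustrative values):

```python
import tensorflow as tf

units = 8
memory_depth = 8  # last dimension of the hypothetical encoder outputs

attention_layer = tf.keras.layers.Dense(units * 2, use_bias=False)

# Mirrors _get_attention_layer_size: the layer input is the cell output
# concatenated with the attention mechanism output.
output_size = attention_layer.compute_output_shape(
    [None, units + memory_depth])[-1]
print(output_size)  # 16, i.e. units * 2
```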