diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py
index ac4c57b250..191c5a3d1c 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper.py
@@ -179,6 +179,10 @@ def __call__(self, inputs, **kwargs):
             inputs: the inputs tensors.
             **kwargs: dict, other keyeword arguments for the `__call__()`
         """
+        # Allow manual memory reset
+        if kwargs.get('setup_memory', False):
+            self._memory_initialized = False
+
         if self._memory_initialized:
             if len(inputs) not in (2, 3):
                 raise ValueError(
@@ -188,6 +192,7 @@ def __call__(self, inputs, **kwargs):
                 # We append the calculated memory here so that the graph will be
                 # connected.
                 inputs.append(self.values)
+
         return super(_BaseAttentionMechanism, self).__call__(inputs, **kwargs)

     def call(self, inputs, mask=None, setup_memory=False, **kwargs):
diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py
index fcb955da9f..7705e22bd0 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper_test.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -209,7 +209,55 @@ def test_masking(self):
         alignment = self.evaluate(alignment)
         self.assertEqual(np.sum(np.triu(alignment, k=1)), 0)

-    # TODO(scottzhu): Add tests for model.compile(run_eagerly=True)
+    @parameterized.named_parameters(
+        ("luong", wrapper.LuongAttention),
+        ("luong_monotonic", wrapper.LuongMonotonicAttention),
+        ("bahdanau", wrapper.BahdanauAttention),
+        ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention),
+    )
+    def test_memory_re_setup(self, attention_cls):
+        class MyModel(tf.keras.models.Model):
+            def __init__(self, vocab, embedding_dim, memory_size, units):
+                super(MyModel, self).__init__()
+                self.emb = tf.keras.layers.Embedding(
+                    vocab, embedding_dim, mask_zero=True)
+                self.encoder = tf.keras.layers.LSTM(
+                    memory_size, return_sequences=True)
+                self.attn_mch = attention_cls(units)
+
+            def call(self, inputs):
+                enc_input, query, state = inputs
+                mask = self.emb.compute_mask(enc_input)
+                enc_input = self.emb(enc_input)
+                enc_output = self.encoder(enc_input, mask=mask)
+                # To ensure manual resetting also works in the graph mode,
+                # we call the attention mechanism twice.
+                self.attn_mch(enc_output, mask=mask, setup_memory=True)
+                self.attn_mch(enc_output, mask=mask, setup_memory=True)
+                score = self.attn_mch([query, state])
+                return score
+
+        vocab = 20
+        embedding_dim = 6
+        num_batches = 5
+
+        model = MyModel(vocab, embedding_dim, self.memory_size, self.units)
+        if tf.executing_eagerly():
+            model.compile("rmsprop", "mse", run_eagerly=True)
+        else:
+            model.compile("rmsprop", "mse")
+
+        x = np.random.randint(
+            vocab, size=(num_batches * self.batch, self.timestep))
+        x_test = np.random.randint(
+            vocab, size=(num_batches * self.batch, self.timestep))
+        y = np.random.randn(num_batches * self.batch, self.timestep)
+
+        query = np.tile(self.query, [num_batches, 1])
+        state = np.tile(self.state, [num_batches, 1])
+
+        model.fit([x, query, state], (y, y), batch_size=self.batch)
+        model.predict_on_batch([x_test, query, state])


 class ResultSummary(
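
Below is a minimal usage sketch (not part of the patch) of what the change enables: calling an attention mechanism a second time with `setup_memory=True` now clears `_memory_initialized` and re-initializes the memory, instead of routing the call through the already-initialized code path. The layer sizes, tensor shapes, and variable names are illustrative assumptions, not taken from the repository.

```python
# Sketch only: shapes, sizes and names below are illustrative assumptions.
import tensorflow as tf
import tensorflow_addons as tfa

batch, timestep, memory_size, units = 4, 5, 10, 8

attn = tfa.seq2seq.LuongAttention(units)

memory_1 = tf.random.normal([batch, timestep, memory_size])
memory_2 = tf.random.normal([batch, timestep, memory_size])
query = tf.random.normal([batch, units])   # Luong: query depth equals units
state = tf.zeros([batch, timestep])        # previous alignments

# First call sets the memory up, as before this patch.
attn(memory_1, setup_memory=True)
# Second call with setup_memory=True now resets and re-initializes the memory.
attn(memory_2, setup_memory=True)

# Regular scoring call against the freshly set memory.
alignments, next_state = attn([query, state])
print(alignments.shape)  # (4, 5)
```

Exposing the reset through the existing `setup_memory` keyword keeps one mechanism instance reusable across different encoder outputs without rebuilding the layer, which is what the new `test_memory_re_setup` test exercises in both eager and graph mode.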