From 6baacc9269b2a370507206c0f9cb00e0264f8e80 Mon Sep 17 00:00:00 2001
From: AmirHosein KazemNejad
Date: Sat, 28 Sep 2019 15:31:31 +0330
Subject: [PATCH 1/2] Add support for manual memory reset

---
 .../seq2seq/attention_wrapper.py      |  5 ++
 .../seq2seq/attention_wrapper_test.py | 51 +++++++++++++++++++
 2 files changed, 56 insertions(+)

diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py
index ac4c57b250..191c5a3d1c 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper.py
@@ -179,6 +179,10 @@ def __call__(self, inputs, **kwargs):
           inputs: the inputs tensors.
           **kwargs: dict, other keyeword arguments for the `__call__()`
         """
+        # Allow manual memory reset
+        if kwargs.get('setup_memory', False):
+            self._memory_initialized = False
+
         if self._memory_initialized:
             if len(inputs) not in (2, 3):
                 raise ValueError(
@@ -188,6 +192,7 @@
                 # We append the calculated memory here so that the graph will be
                 # connected.
                 inputs.append(self.values)
+
         return super(_BaseAttentionMechanism, self).__call__(inputs, **kwargs)
 
     def call(self, inputs, mask=None, setup_memory=False, **kwargs):
diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py
index fcb955da9f..36ac9b9edc 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper_test.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -209,6 +209,57 @@ def test_masking(self):
         alignment = self.evaluate(alignment)
         self.assertEqual(np.sum(np.triu(alignment, k=1)), 0)
 
+    @parameterized.named_parameters(
+        ("luong", wrapper.LuongAttention),
+        ("luong_monotonic", wrapper.LuongMonotonicAttention),
+        ("bahdanau", wrapper.BahdanauAttention),
+        ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention),
+    )
+    def test_memory_re_setup(self, attention_cls):
+        class MyModel(tf.keras.models.Model):
+            def __init__(self):
+                super(MyModel, self).__init__()
+                self.emb = tf.keras.layers.Embedding(
+                    vocab, embedding_dim, mask_zero=True)
+                self.encoder = tf.keras.layers.LSTM(
+                    test_class_self.memory_size, return_sequences=True)
+                self.attn_mch = attention_cls(test_class_self.units)
+
+            def call(self, inputs):
+                enc_input, query, state = inputs
+                mask = self.emb.compute_mask(enc_input)
+                enc_input = self.emb(enc_input)
+                enc_output = self.encoder(enc_input, mask=mask)
+                # To ensure manual resetting also works in graph mode,
+                # we call the attention mechanism twice.
+                self.attn_mch(enc_output, mask=mask, setup_memory=True)
+                self.attn_mch(enc_output, mask=mask, setup_memory=True)
+                score = self.attn_mch([query, state])
+                return score
+
+        test_class_self = self
+        vocab = 20
+        embedding_dim = 6
+        num_batches = 5
+
+        model = MyModel()
+        if tf.executing_eagerly():
+            model.compile("rmsprop", "mse", run_eagerly=True)
+        else:
+            model.compile("rmsprop", "mse")
+
+        x = np.random.randint(
+            vocab, size=(num_batches * self.batch, self.timestep))
+        x_test = np.random.randint(
+            vocab, size=(num_batches * self.batch, self.timestep))
+        y = np.random.randn(num_batches * self.batch, self.timestep)
+
+        query = np.tile(self.query, [num_batches, 1])
+        state = np.tile(self.state, [num_batches, 1])
+
+        model.fit([x, query, state], (y, y), batch_size=self.batch)
+        model.predict_on_batch([x_test, query, state])
+
     # TODO(scottzhu): Add tests for model.compile(run_eagerly=True)
 
 

From 3a182b50680ee57cb894347ff3ac1cafdd68ad54 Mon Sep 17 00:00:00 2001
From: AmirHosein KazemNejad
Date: Mon, 30 Sep 2019 19:40:24 +0330
Subject: [PATCH 2/2] Better code style in MyModel class

Remove the resolved TODO
---
 tensorflow_addons/seq2seq/attention_wrapper_test.py | 11 ++++-------
 1 file changed, 4 insertions(+), 7 deletions(-)

diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py
index 36ac9b9edc..7705e22bd0 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper_test.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -217,13 +217,13 @@ def test_masking(self):
     )
     def test_memory_re_setup(self, attention_cls):
         class MyModel(tf.keras.models.Model):
-            def __init__(self):
+            def __init__(self, vocab, embedding_dim, memory_size, units):
                 super(MyModel, self).__init__()
                 self.emb = tf.keras.layers.Embedding(
                     vocab, embedding_dim, mask_zero=True)
                 self.encoder = tf.keras.layers.LSTM(
-                    test_class_self.memory_size, return_sequences=True)
-                self.attn_mch = attention_cls(test_class_self.units)
+                    memory_size, return_sequences=True)
+                self.attn_mch = attention_cls(units)
 
             def call(self, inputs):
                 enc_input, query, state = inputs
                 mask = self.emb.compute_mask(enc_input)
                 enc_input = self.emb(enc_input)
                 enc_output = self.encoder(enc_input, mask=mask)
@@ -237,12 +237,11 @@ def call(self, inputs):
                 score = self.attn_mch([query, state])
                 return score
 
-        test_class_self = self
         vocab = 20
         embedding_dim = 6
         num_batches = 5
 
-        model = MyModel()
+        model = MyModel(vocab, embedding_dim, self.memory_size, self.units)
         if tf.executing_eagerly():
             model.compile("rmsprop", "mse", run_eagerly=True)
         else:
             model.compile("rmsprop", "mse")
@@ -260,8 +259,6 @@ def call(self, inputs):
         model.fit([x, query, state], (y, y), batch_size=self.batch)
         model.predict_on_batch([x_test, query, state])
 
-    # TODO(scottzhu): Add tests for model.compile(run_eagerly=True)
-
 
 class ResultSummary(
     collections.namedtuple("ResultSummary", ("shape", "dtype", "mean"))):
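
For context, the snippet below is a minimal usage sketch of what these patches enable; it is not part of the diffs above. With the change applied, calling an attention mechanism a second time with `setup_memory=True` simply re-initializes its memory. The choice of LuongAttention and the shapes and sizes are illustrative assumptions; the call pattern mirrors test_memory_re_setup.

# Minimal sketch, assuming tensorflow-addons with PATCH 1/2 applied.
# LuongAttention and all shapes/sizes here are illustrative, not taken from the PR.
import tensorflow as tf
import tensorflow_addons as tfa

batch, timestep, memory_size, units = 2, 5, 6, 8

attn = tfa.seq2seq.LuongAttention(units)

memory_a = tf.random.normal([batch, timestep, memory_size])
memory_b = tf.random.normal([batch, timestep, memory_size])
query = tf.random.normal([batch, units])  # Luong scoring expects query depth == units
state = tf.zeros([batch, timestep])       # previous alignments (ignored by softmax)

# First memory setup: same behaviour as before the patch.
attn(memory_a, setup_memory=True)
score_a = attn([query, state])

# With the patch, a second setup_memory=True call clears the initialized flag
# first, so the memory is simply set up again with the new tensor.
attn(memory_b, setup_memory=True)
score_b = attn([query, state])

In test_memory_re_setup the same call sequence runs inside a Keras Model, so the reset is also exercised in graph (non-eager) execution.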