From fafc98d8a785655e49aba0ef16f6d450f1a81da6 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Mon, 15 Jul 2019 14:00:56 +0200
Subject: [PATCH 1/2] Improve API for resetting AttentionMechanism memory

---
 .../seq2seq/attention_wrapper.py      | 24 ++++++------
 .../seq2seq/attention_wrapper_test.py | 39 +++++++++++++++++++
 2 files changed, 50 insertions(+), 13 deletions(-)

diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py
index 6345d4d43c..fd43aae54f 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper.py
@@ -116,6 +116,7 @@ def __init__(self,
         if not callable(probability_fn):
             raise TypeError("probability_fn must be callable, saw type: %s" %
                             type(probability_fn).__name__)
+        self.default_probability_fn = probability_fn
         self.probability_fn = probability_fn
 
         self.keys = None
@@ -226,7 +227,7 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
             else:
                 memory, memory_sequence_length = inputs, None
                 memory_mask = mask
-            self._setup_memory(memory, memory_sequence_length, memory_mask)
+            self.setup_memory(memory, memory_sequence_length, memory_mask)
             # We force the self.built to false here since only memory is,
             # initialized but the real query/state has not been call() yet. The
             # layer should be build and call again.
@@ -248,10 +249,10 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
             query, state = inputs[0], inputs[1]
             return self._calculate_attention(query, state)
 
-    def _setup_memory(self,
-                      memory,
-                      memory_sequence_length=None,
-                      memory_mask=None):
+    def setup_memory(self,
+                     memory,
+                     memory_sequence_length=None,
+                     memory_mask=None):
         """Pre-process the memory before actually query the memory.
 
         This should only be called once at the first invocation of call().
@@ -266,9 +267,6 @@ def _setup_memory(self,
             max_time]`. For any value equal to False, the corresponding value
             in memory should be ignored.
""" - if self._memory_initialized: - raise ValueError( - "The memory for the attention has already been setup.") if memory_sequence_length is not None and memory_mask is not None: raise ValueError( "memory_sequence_length and memory_mask cannot be " @@ -293,7 +291,7 @@ def _setup_memory(self, self._alignments_size = (tf.compat.dimension_value( self.keys.shape[1]) or tf.shape(self.keys)[1]) if memory_mask is not None or memory_sequence_length is not None: - unwrapped_probability_fn = self.probability_fn + unwrapped_probability_fn = self.default_probability_fn def _mask_probability_fn(score, prev): return unwrapped_probability_fn( @@ -505,7 +503,7 @@ class LuongAttention(_BaseAttentionMechanism): def __init__(self, units, - memory, + memory=None, memory_sequence_length=None, scale=False, probability_fn="softmax", @@ -671,7 +669,7 @@ class BahdanauAttention(_BaseAttentionMechanism): def __init__(self, units, - memory, + memory=None, memory_sequence_length=None, normalize=False, probability_fn="softmax", @@ -1013,7 +1011,7 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism): def __init__(self, units, - memory, + memory=None, memory_sequence_length=None, normalize=False, sigmoid_noise=0., @@ -1186,7 +1184,7 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism): def __init__(self, units, - memory, + memory=None, memory_sequence_length=None, scale=False, sigmoid_noise=0., diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py index d9538436aa..f692f66297 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper_test.py +++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py @@ -44,6 +44,8 @@ def setUp(self): self.memory = np.random.randn(self.batch, self.timestep, self.memory_size).astype(np.float32) + self.memory_length = np.random.randint( + low=1, high=self.timestep + 1, size=(self.batch,)) self.query = np.random.randn(self.batch, self.units).astype(np.float32) self.state = np.random.randn(self.batch, self.timestep).astype(np.float32) @@ -159,6 +161,43 @@ def test_save_load_layer(self, attention_cls): self.assertAllClose(y_ref, y) + @parameterized.named_parameters( + ("luong", wrapper.LuongAttention), + ("luong_monotonic", wrapper.LuongMonotonicAttention), + ("bahdanau", wrapper.BahdanauAttention), + ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention), + ) + def test_manual_memory_reset(self, attention_cls): + attention = attention_cls(self.units) + + def _compute_score(batch_size=None): + if batch_size is None: + batch_size = self.batch + memory = self.memory[:batch_size] + attention.setup_memory( + memory, + memory_sequence_length=self.memory_length[:batch_size]) + self.assertListEqual( + attention.values.shape.as_list(), + list(memory.shape)) + self.assertListEqual( + attention.keys.shape.as_list(), + list(memory.shape)[:-1] + [self.units]) + return attention( + [self.query[:batch_size], self.state[:batch_size]]) + + score = _compute_score(batch_size=self.batch) + variables = list(attention.variables) + score = _compute_score(batch_size=self.batch - 1) + + # No new variables were created. + for var_1, var_2 in zip(variables, list(attention.variables)): + self.assertIs(var_1, var_2) + + # Score can be computed without errors. 
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        self.evaluate(score)
+
     def test_masking(self):
         memory = tf.ones([4, 4, 5], dtype=tf.float32)
         memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)

From 9ae3d274b64f58a89714124682223e3893b2b3f7 Mon Sep 17 00:00:00 2001
From: Guillaume Klein
Date: Wed, 17 Jul 2019 16:54:16 +0200
Subject: [PATCH 2/2] Apply recommended code formatting

---
 tensorflow_addons/seq2seq/attention_wrapper_test.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py
index f692f66297..e1fdd2a900 100644
--- a/tensorflow_addons/seq2seq/attention_wrapper_test.py
+++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -175,14 +175,11 @@ def _compute_score(batch_size=None):
                 batch_size = self.batch
             memory = self.memory[:batch_size]
             attention.setup_memory(
-                memory,
-                memory_sequence_length=self.memory_length[:batch_size])
-            self.assertListEqual(
-                attention.values.shape.as_list(),
-                list(memory.shape))
-            self.assertListEqual(
-                attention.keys.shape.as_list(),
-                list(memory.shape)[:-1] + [self.units])
+                memory, memory_sequence_length=self.memory_length[:batch_size])
+            self.assertListEqual(attention.values.shape.as_list(),
+                                 list(memory.shape))
+            self.assertListEqual(attention.keys.shape.as_list(),
+                                 list(memory.shape)[:-1] + [self.units])
             return attention(
                 [self.query[:batch_size], self.state[:batch_size]])
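Usage note (editor's sketch, not part of the patch series): with these changes applied, an attention mechanism can be constructed without memory and later pointed at a new batch through the now-public setup_memory() method, mirroring the new test_manual_memory_reset test above. The shapes, batch sizes, and the assumption of TF 2.x eager execution below are illustrative only.

    import numpy as np
    from tensorflow_addons.seq2seq import attention_wrapper as wrapper

    units, memory_size = 8, 6
    # memory can now be omitted at construction time (memory=None).
    attention = wrapper.BahdanauAttention(units)

    for batch, timestep in [(4, 10), (3, 7)]:
        memory = np.random.randn(batch, timestep, memory_size).astype(np.float32)
        lengths = np.random.randint(1, timestep + 1, size=(batch,))
        # setup_memory() resets the keys/values for the new batch; as the test
        # checks, existing variables are reused rather than recreated.
        attention.setup_memory(memory, memory_sequence_length=lengths)

        query = np.random.randn(batch, units).astype(np.float32)
        state = np.random.randn(batch, timestep).astype(np.float32)
        score = attention([query, state])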