
Commit 13cf110

guillaumekln authored and seanpmorgan committed
Improve API for resetting AttentionMechanism memory (#354)
* Improve API for resetting AttentionMechanism memory
1 parent: d15764b · commit: 13cf110

File tree

2 files changed (+47, -13 lines)


tensorflow_addons/seq2seq/attention_wrapper.py

Lines changed: 11 additions & 13 deletions
@@ -116,6 +116,7 @@ def __init__(self,
         if not callable(probability_fn):
             raise TypeError("probability_fn must be callable, saw type: %s" %
                             type(probability_fn).__name__)
+        self.default_probability_fn = probability_fn
         self.probability_fn = probability_fn
 
         self.keys = None
@@ -226,7 +227,7 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
             else:
                 memory, memory_sequence_length = inputs, None
                 memory_mask = mask
-            self._setup_memory(memory, memory_sequence_length, memory_mask)
+            self.setup_memory(memory, memory_sequence_length, memory_mask)
             # We force the self.built to false here since only memory is,
             # initialized but the real query/state has not been call() yet. The
             # layer should be build and call again.
@@ -248,10 +249,10 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
             query, state = inputs[0], inputs[1]
             return self._calculate_attention(query, state)
 
-    def _setup_memory(self,
-                      memory,
-                      memory_sequence_length=None,
-                      memory_mask=None):
+    def setup_memory(self,
+                     memory,
+                     memory_sequence_length=None,
+                     memory_mask=None):
         """Pre-process the memory before actually query the memory.
 
         This should only be called once at the first invocation of call().
@@ -266,9 +267,6 @@ def _setup_memory(self,
             max_time]`. For any value equal to False, the corresponding value
             in memory should be ignored.
         """
-        if self._memory_initialized:
-            raise ValueError(
-                "The memory for the attention has already been setup.")
         if memory_sequence_length is not None and memory_mask is not None:
             raise ValueError(
                 "memory_sequence_length and memory_mask cannot be "
@@ -293,7 +291,7 @@ def _setup_memory(self,
         self._alignments_size = (tf.compat.dimension_value(
             self.keys.shape[1]) or tf.shape(self.keys)[1])
         if memory_mask is not None or memory_sequence_length is not None:
-            unwrapped_probability_fn = self.probability_fn
+            unwrapped_probability_fn = self.default_probability_fn
 
             def _mask_probability_fn(score, prev):
                 return unwrapped_probability_fn(
@@ -505,7 +503,7 @@ class LuongAttention(_BaseAttentionMechanism):
 
     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  scale=False,
                  probability_fn="softmax",
@@ -671,7 +669,7 @@ class BahdanauAttention(_BaseAttentionMechanism):
 
     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  normalize=False,
                  probability_fn="softmax",
@@ -1013,7 +1011,7 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
 
     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  normalize=False,
                  sigmoid_noise=0.,
@@ -1186,7 +1184,7 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
 
     def __init__(self,
                  units,
-                 memory,
+                 memory=None,
                  memory_sequence_length=None,
                  scale=False,
                  sigmoid_noise=0.,
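
With `memory` now optional in the attention constructors and `setup_memory()` public, a mechanism can be built once and later re-pointed at a new memory tensor. The snippet below is a minimal usage sketch inferred from the signatures in this diff; the import alias, shapes, and random inputs are illustrative and not part of the commit.

import numpy as np
import tensorflow_addons as tfa

# Build the mechanism without memory; `memory` defaults to None after this change.
attention = tfa.seq2seq.LuongAttention(units=64)

# Attach a memory tensor through the now-public setup_memory().
memory = np.random.randn(8, 10, 32).astype(np.float32)   # [batch, max_time, depth]
lengths = np.full((8,), 10, dtype=np.int32)               # valid timesteps per batch entry
attention.setup_memory(memory, memory_sequence_length=lengths)

# Resetting no longer raises; calling setup_memory() again replaces keys/values.
new_memory = np.random.randn(4, 7, 32).astype(np.float32)
new_lengths = np.full((4,), 7, dtype=np.int32)
attention.setup_memory(new_memory, memory_sequence_length=new_lengths)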

tensorflow_addons/seq2seq/attention_wrapper_test.py

Lines changed: 36 additions & 0 deletions
@@ -44,6 +44,8 @@ def setUp(self):
 
         self.memory = np.random.randn(self.batch, self.timestep,
                                       self.memory_size).astype(np.float32)
+        self.memory_length = np.random.randint(
+            low=1, high=self.timestep + 1, size=(self.batch,))
         self.query = np.random.randn(self.batch, self.units).astype(np.float32)
         self.state = np.random.randn(self.batch,
                                      self.timestep).astype(np.float32)
@@ -159,6 +161,40 @@ def test_save_load_layer(self, attention_cls):
 
         self.assertAllClose(y_ref, y)
 
+    @parameterized.named_parameters(
+        ("luong", wrapper.LuongAttention),
+        ("luong_monotonic", wrapper.LuongMonotonicAttention),
+        ("bahdanau", wrapper.BahdanauAttention),
+        ("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention),
+    )
+    def test_manual_memory_reset(self, attention_cls):
+        attention = attention_cls(self.units)
+
+        def _compute_score(batch_size=None):
+            if batch_size is None:
+                batch_size = self.batch
+            memory = self.memory[:batch_size]
+            attention.setup_memory(
+                memory, memory_sequence_length=self.memory_length[:batch_size])
+            self.assertListEqual(attention.values.shape.as_list(),
+                                 list(memory.shape))
+            self.assertListEqual(attention.keys.shape.as_list(),
+                                 list(memory.shape)[:-1] + [self.units])
+            return attention(
+                [self.query[:batch_size], self.state[:batch_size]])
+
+        score = _compute_score(batch_size=self.batch)
+        variables = list(attention.variables)
+        score = _compute_score(batch_size=self.batch - 1)
+
+        # No new variables were created.
+        for var_1, var_2 in zip(variables, list(attention.variables)):
+            self.assertIs(var_1, var_2)
+
+        # Score can be computed without errors.
+        self.evaluate(tf.compat.v1.global_variables_initializer())
+        self.evaluate(score)
+
     def test_masking(self):
         memory = tf.ones([4, 4, 5], dtype=tf.float32)
         memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)
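
Looking back at the `default_probability_fn` hunks in attention_wrapper.py: since `setup_memory()` can now run more than once, the masked probability function has to be rebuilt from the original, unwrapped function on every reset; otherwise each reset would nest the new memory's mask around the previous one. The sketch below illustrates that idea only; `wrap_with_mask`, the lambda, and the toy masks are made up for illustration and are not the library's actual code.

import tensorflow as tf

def wrap_with_mask(base_fn, mask):
    # Push masked-out positions to -inf before applying the base probability function.
    def masked_fn(score, prev_alignments):
        padded = tf.where(mask, score, tf.fill(tf.shape(score), float("-inf")))
        return base_fn(padded, prev_alignments)
    return masked_fn

default_probability_fn = lambda score, prev: tf.nn.softmax(score)  # kept untouched across resets

# First setup_memory(): wrap the stored default with the first memory's mask.
mask_1 = tf.constant([[True, True, False]])
probability_fn = wrap_with_mask(default_probability_fn, mask_1)

# Memory reset: wrap the *default* again rather than the already-wrapped function,
# so only the new memory's mask applies.
mask_2 = tf.constant([[True, False, False]])
probability_fn = wrap_with_mask(default_probability_fn, mask_2)

print(probability_fn(tf.ones([1, 3]), None).numpy())  # all weight on position 0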
