Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions tensorflow_addons/seq2seq/attention_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,10 @@ def __call__(self, inputs, **kwargs):
inputs: the inputs tensors.
**kwargs: dict, other keyeword arguments for the `__call__()`
"""
# Allow manual memory reset
if kwargs.get('setup_memory', False):
self._memory_initialized = False

if self._memory_initialized:
if len(inputs) not in (2, 3):
raise ValueError(
Expand All @@ -188,6 +192,7 @@ def __call__(self, inputs, **kwargs):
# We append the calculated memory here so that the graph will be
# connected.
inputs.append(self.values)

return super(_BaseAttentionMechanism, self).__call__(inputs, **kwargs)

def call(self, inputs, mask=None, setup_memory=False, **kwargs):
Expand Down
50 changes: 49 additions & 1 deletion tensorflow_addons/seq2seq/attention_wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,55 @@ def test_masking(self):
alignment = self.evaluate(alignment)
self.assertEqual(np.sum(np.triu(alignment, k=1)), 0)

# TODO(scottzhu): Add tests for model.compile(run_eagerly=True)
@parameterized.named_parameters(
("luong", wrapper.LuongAttention),
("luong_monotonic", wrapper.LuongMonotonicAttention),
("bahdanau", wrapper.BahdanauAttention),
("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention),
)
def test_memory_re_setup(self, attention_cls):
class MyModel(tf.keras.models.Model):
def __init__(self, vocab, embedding_dim, memory_size, units):
super(MyModel, self).__init__()
self.emb = tf.keras.layers.Embedding(
vocab, embedding_dim, mask_zero=True)
self.encoder = tf.keras.layers.LSTM(
memory_size, return_sequences=True)
self.attn_mch = attention_cls(units)

def call(self, inputs):
enc_input, query, state = inputs
mask = self.emb.compute_mask(enc_input)
enc_input = self.emb(enc_input)
enc_output = self.encoder(enc_input, mask=mask)
# To ensure manual resetting also works in the graph mode,
# we call the attention mechanism twice.
self.attn_mch(enc_output, mask=mask, setup_memory=True)
self.attn_mch(enc_output, mask=mask, setup_memory=True)
score = self.attn_mch([query, state])
return score

vocab = 20
embedding_dim = 6
num_batches = 5

model = MyModel(vocab, embedding_dim, self.memory_size, self.units)
if tf.executing_eagerly():
model.compile("rmsprop", "mse", run_eagerly=True)
else:
model.compile("rmsprop", "mse")

x = np.random.randint(
vocab, size=(num_batches * self.batch, self.timestep))
x_test = np.random.randint(
vocab, size=(num_batches * self.batch, self.timestep))
y = np.random.randn(num_batches * self.batch, self.timestep)

query = np.tile(self.query, [num_batches, 1])
state = np.tile(self.state, [num_batches, 1])

model.fit([x, query, state], (y, y), batch_size=self.batch)
model.predict_on_batch([x_test, query, state])


class ResultSummary(
Expand Down