24 changes: 11 additions & 13 deletions tensorflow_addons/seq2seq/attention_wrapper.py
@@ -116,6 +116,7 @@ def __init__(self,
if not callable(probability_fn):
raise TypeError("probability_fn must be callable, saw type: %s" %
type(probability_fn).__name__)
self.default_probability_fn = probability_fn
self.probability_fn = probability_fn

self.keys = None
@@ -226,7 +227,7 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
else:
memory, memory_sequence_length = inputs, None
memory_mask = mask
self._setup_memory(memory, memory_sequence_length, memory_mask)
self.setup_memory(memory, memory_sequence_length, memory_mask)
# We force self.built to False here since only the memory has been
# initialized; the real query/state has not been passed to call() yet.
# The layer should be built and called again.
@@ -248,10 +249,10 @@ def call(self, inputs, mask=None, setup_memory=False, **kwargs):
query, state = inputs[0], inputs[1]
return self._calculate_attention(query, state)

def _setup_memory(self,
memory,
memory_sequence_length=None,
memory_mask=None):
def setup_memory(self,
memory,
memory_sequence_length=None,
memory_mask=None):
"""Pre-process the memory before actually query the memory.

This should only be called once at the first invocation of call().
@@ -266,9 +267,6 @@ def _setup_memory(self,
max_time]`. For any value equal to False, the corresponding value
in memory should be ignored.
"""
if self._memory_initialized:
raise ValueError(
"The memory for the attention has already been setup.")
if memory_sequence_length is not None and memory_mask is not None:
raise ValueError(
"memory_sequence_length and memory_mask cannot be "
@@ -293,7 +291,7 @@ def _setup_memory(self,
self._alignments_size = (tf.compat.dimension_value(
self.keys.shape[1]) or tf.shape(self.keys)[1])
if memory_mask is not None or memory_sequence_length is not None:
unwrapped_probability_fn = self.probability_fn
unwrapped_probability_fn = self.default_probability_fn

def _mask_probability_fn(score, prev):
return unwrapped_probability_fn(
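For context, the change above keeps the unwrapped probability_fn around as default_probability_fn so that every call to setup_memory wraps the original function with the current memory's mask, instead of wrapping the already-masked self.probability_fn left over from a previous setup. A minimal sketch of that re-wrapping pattern, using toy names (TinyMechanism, masked_fn) and a plain softmax rather than the library's internal helpers:

```python
import tensorflow as tf


class TinyMechanism:
    """Toy stand-in illustrating why the unwrapped fn is kept around."""

    def __init__(self, probability_fn=tf.nn.softmax):
        # Keep the original fn so every setup_memory() can wrap it afresh.
        self.default_probability_fn = probability_fn
        self.probability_fn = probability_fn

    def setup_memory(self, memory_mask):
        base_fn = self.default_probability_fn  # not the possibly-wrapped one

        def masked_fn(score, prev_alignments=None):
            # The toy softmax ignores the previous alignments.
            del prev_alignments
            # Push masked-out positions to a very negative value before normalizing.
            neg_inf = tf.fill(tf.shape(score), score.dtype.min)
            return base_fn(tf.where(memory_mask, score, neg_inf))

        self.probability_fn = masked_fn


# Re-running setup_memory with a new mask replaces, rather than stacks, the wrapper.
mech = TinyMechanism()
mech.setup_memory(tf.constant([[True, True, False]]))
mech.setup_memory(tf.constant([[True, False, False]]))
print(mech.probability_fn(tf.constant([[1.0, 2.0, 3.0]])))
```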
@@ -505,7 +503,7 @@ class LuongAttention(_BaseAttentionMechanism):

def __init__(self,
units,
memory,
memory=None,
memory_sequence_length=None,
scale=False,
probability_fn="softmax",
@@ -671,7 +669,7 @@ class BahdanauAttention(_BaseAttentionMechanism):

def __init__(self,
units,
memory,
memory=None,
memory_sequence_length=None,
normalize=False,
probability_fn="softmax",
@@ -1013,7 +1011,7 @@ class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):

def __init__(self,
units,
memory,
memory=None,
memory_sequence_length=None,
normalize=False,
sigmoid_noise=0.,
@@ -1186,7 +1184,7 @@ class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):

def __init__(self,
units,
memory,
memory=None,
memory_sequence_length=None,
scale=False,
sigmoid_noise=0.,
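With memory now defaulting to None in all four constructors, a mechanism can be built first and given its memory later, or given a new memory, through the public setup_memory(), which is what the new test below exercises. A minimal usage sketch; the LuongAttention choice, shapes, and import style are illustrative only:

```python
import numpy as np
from tensorflow_addons.seq2seq import attention_wrapper as wrapper

batch, timestep, memory_size, units = 8, 10, 6, 4
memory = np.random.randn(batch, timestep, memory_size).astype(np.float32)
memory_length = np.random.randint(1, timestep + 1, size=(batch,))
query = np.random.randn(batch, units).astype(np.float32)
state = np.random.randn(batch, timestep).astype(np.float32)

# memory can now be omitted at construction time ...
attention = wrapper.LuongAttention(units)
# ... and supplied, or later swapped out, via the public setup_memory().
attention.setup_memory(memory, memory_sequence_length=memory_length)
score = attention([query, state])
```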
36 changes: 36 additions & 0 deletions tensorflow_addons/seq2seq/attention_wrapper_test.py
@@ -44,6 +44,8 @@ def setUp(self):

self.memory = np.random.randn(self.batch, self.timestep,
self.memory_size).astype(np.float32)
self.memory_length = np.random.randint(
low=1, high=self.timestep + 1, size=(self.batch,))
self.query = np.random.randn(self.batch, self.units).astype(np.float32)
self.state = np.random.randn(self.batch,
self.timestep).astype(np.float32)
@@ -159,6 +161,40 @@ def test_save_load_layer(self, attention_cls):

self.assertAllClose(y_ref, y)

@parameterized.named_parameters(
("luong", wrapper.LuongAttention),
("luong_monotonic", wrapper.LuongMonotonicAttention),
("bahdanau", wrapper.BahdanauAttention),
("bahdanau_monotonic", wrapper.BahdanauMonotonicAttention),
)
def test_manual_memory_reset(self, attention_cls):
attention = attention_cls(self.units)

def _compute_score(batch_size=None):
if batch_size is None:
batch_size = self.batch
memory = self.memory[:batch_size]
attention.setup_memory(
memory, memory_sequence_length=self.memory_length[:batch_size])
self.assertListEqual(attention.values.shape.as_list(),
list(memory.shape))
self.assertListEqual(attention.keys.shape.as_list(),
list(memory.shape)[:-1] + [self.units])
return attention(
[self.query[:batch_size], self.state[:batch_size]])

score = _compute_score(batch_size=self.batch)
variables = list(attention.variables)
score = _compute_score(batch_size=self.batch - 1)

# No new variables were created.
for var_1, var_2 in zip(variables, list(attention.variables)):
self.assertIs(var_1, var_2)

# Score can be computed without errors.
self.evaluate(tf.compat.v1.global_variables_initializer())
self.evaluate(score)

def test_masking(self):
memory = tf.ones([4, 4, 5], dtype=tf.float32)
memory_sequence_length = tf.constant([1, 2, 3, 4], dtype=tf.int32)