diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index 3abe62bd8c..d8b46cc7c5 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -124,7 +124,6 @@ def __init__(self, self._memory_initialized = False self._check_inner_dims_defined = True self.supports_masking = True - self.score_mask_value = tf.as_dtype(self.dtype).as_numpy_dtype(-np.inf) if memory is not None: # Setup the memory by self.__call__() with memory and @@ -302,7 +301,7 @@ def _mask_probability_fn(score, prev): score, memory_mask=memory_mask, memory_sequence_length=memory_sequence_length, - score_mask_value=self.score_mask_value), prev) + score_mask_value=score.dtype.min), prev) self.probability_fn = _mask_probability_fn self._memory_initialized = True