
Commit 556641c

WindQAQ authored and seanpmorgan committed
fix float64 tests (#419)
1 parent f73c338 commit 556641c

2 files changed: 14 additions, 12 deletions


tensorflow_addons/seq2seq/attention_wrapper.py

Lines changed: 4 additions & 2 deletions
@@ -1538,7 +1538,8 @@ def __init__(self,
                initial_cell_state=None,
                name=None,
                attention_layer=None,
-               attention_fn=None):
+               attention_fn=None,
+               **kwargs):
     """Construct the `AttentionWrapper`.
 
     **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped
@@ -1619,6 +1620,7 @@ def __init__(self,
         attention_layer) and outputs (attention, alignments,
         next_attention_state). If provided, the attention_layer_size should
         be the size of the outputs of attention_fn.
+      **kwargs: Other keyword arguments for layer creation.
 
     Raises:
       TypeError: `attention_layer_size` is not None and
@@ -1629,7 +1631,7 @@ def __init__(self,
         of `attention_layer_size`; if `attention_layer_size` and
         `attention_layer` are set simultaneously.
     """
-    super(AttentionWrapper, self).__init__(name=name)
+    super(AttentionWrapper, self).__init__(name=name, **kwargs)
     rnn_cell_impl.assert_like_rnncell("cell", cell)
     if isinstance(attention_mechanism, (list, tuple)):
       self._is_multi = True
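With `**kwargs` now forwarded to the Keras base layer constructor, standard layer arguments such as `dtype` can be passed straight to `AttentionWrapper`, which is what the float64 tests below rely on. A minimal construction sketch, assuming tfa's `BahdanauAttention` mechanism and made-up shapes (none of these values come from the commit):

import numpy as np
import tensorflow as tf
from tensorflow_addons.seq2seq import attention_wrapper as wrapper

# Made-up shapes, for illustration only.
batch_size, max_time, units = 8, 10, 32
memory = tf.constant(
    np.random.randn(batch_size, max_time, units), dtype=tf.float64)

# Attention mechanism and wrapped cell are both built as float64.
attention_mechanism = wrapper.BahdanauAttention(
    units=units, memory=memory, dtype=tf.float64)
cell = tf.keras.layers.LSTMCell(units, dtype=tf.float64)

# dtype is not a named parameter of AttentionWrapper.__init__; it is picked up
# by the new **kwargs and forwarded to the Keras base layer constructor.
attn_cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=tf.float64)

Forwarding `**kwargs` rather than adding an explicit `dtype` parameter keeps the signature in line with other Keras layers, which accept `dtype`, `trainable`, and similar arguments through the base class.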

tensorflow_addons/seq2seq/attention_wrapper_test.py

Lines changed: 10 additions & 10 deletions
@@ -466,8 +466,7 @@ def _testWithMaybeMultiAttention(self,
         expected_final_alignment_history,
         final_alignment_history_info)
 
-  # TODO: #407 Float64 test is failing
-  @parameterized.parameters([np.float32])
+  @parameterized.parameters([np.float32, np.float64])
   def testBahdanauNormalizedDType(self, dtype):
     encoder_outputs = self.encoder_outputs.astype(dtype)
     decoder_inputs = self.decoder_inputs.astype(dtype)
@@ -478,11 +477,12 @@ def testBahdanauNormalizedDType(self, dtype):
         normalize=True,
         dtype=dtype)
     cell = keras.layers.LSTMCell(
-        self.units, recurrent_activation="sigmoid")
-    cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+        self.units, recurrent_activation="sigmoid", dtype=dtype)
+    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
 
     sampler = sampler_py.TrainingSampler()
-    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
+    my_decoder = basic_decoder.BasicDecoder(
+        cell=cell, sampler=sampler, dtype=dtype)
 
     final_outputs, final_state, _ = my_decoder(
         decoder_inputs,
@@ -493,8 +493,7 @@ def testBahdanauNormalizedDType(self, dtype):
     self.assertEqual(final_outputs.rnn_output.dtype, dtype)
     self.assertIsInstance(final_state, wrapper.AttentionWrapperState)
 
-  # TODO: #407 Float64 test is failing
-  @parameterized.parameters([np.float32])
+  @parameterized.parameters([np.float32, np.float64])
   def testLuongScaledDType(self, dtype):
     # Test case for GitHub issue 18099
     encoder_outputs = self.encoder_outputs.astype(dtype)
@@ -507,11 +506,12 @@ def testLuongScaledDType(self, dtype):
         dtype=dtype,
     )
     cell = keras.layers.LSTMCell(
-        self.units, recurrent_activation="sigmoid")
-    cell = wrapper.AttentionWrapper(cell, attention_mechanism)
+        self.units, recurrent_activation="sigmoid", dtype=dtype)
+    cell = wrapper.AttentionWrapper(cell, attention_mechanism, dtype=dtype)
 
     sampler = sampler_py.TrainingSampler()
-    my_decoder = basic_decoder.BasicDecoder(cell=cell, sampler=sampler)
+    my_decoder = basic_decoder.BasicDecoder(
+        cell=cell, sampler=sampler, dtype=dtype)
 
     final_outputs, final_state, _ = my_decoder(
         decoder_inputs,
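The test changes come down to dtype consistency: the attention mechanism, the LSTM cell, the AttentionWrapper, and the BasicDecoder all have to be built with the same dtype, otherwise the float64 inputs run into float32 variables and fail with a dtype mismatch. A sketch of the end-to-end pattern the re-enabled float64 tests exercise, assuming the public tfa.seq2seq aliases for the same classes and made-up shapes and sequence lengths:

import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

dtype = tf.float64
batch_size, max_time, units = 8, 10, 32

encoder_outputs = tf.constant(
    np.random.randn(batch_size, max_time, units), dtype=dtype)
decoder_inputs = tf.constant(
    np.random.randn(batch_size, max_time, units), dtype=dtype)

# Every component in the chain is created with the same dtype.
attention_mechanism = tfa.seq2seq.BahdanauAttention(
    units=units, memory=encoder_outputs, normalize=True, dtype=dtype)
cell = tf.keras.layers.LSTMCell(
    units, recurrent_activation="sigmoid", dtype=dtype)
cell = tfa.seq2seq.AttentionWrapper(cell, attention_mechanism, dtype=dtype)

sampler = tfa.seq2seq.TrainingSampler()
decoder = tfa.seq2seq.BasicDecoder(cell=cell, sampler=sampler, dtype=dtype)

# Invocation mirrors the test above; shapes and sequence lengths are illustrative.
final_outputs, final_state, _ = decoder(
    decoder_inputs,
    initial_state=cell.get_initial_state(
        batch_size=batch_size, dtype=dtype),
    sequence_length=tf.fill([batch_size], max_time))

# With consistent dtypes, the decoder output comes back as float64.
assert final_outputs.rnn_output.dtype == dtype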
