@@ -466,8 +466,7 @@ def _testWithMaybeMultiAttention(self,
466466 expected_final_alignment_history ,
467467 final_alignment_history_info )
468468
469- # TODO: #407 Float64 test is failing
470- @parameterized .parameters ([np .float32 ])
469+ @parameterized .parameters ([np .float32 , np .float64 ])
471470 def testBahdanauNormalizedDType (self , dtype ):
472471 encoder_outputs = self .encoder_outputs .astype (dtype )
473472 decoder_inputs = self .decoder_inputs .astype (dtype )
@@ -478,11 +477,12 @@ def testBahdanauNormalizedDType(self, dtype):
478477 normalize = True ,
479478 dtype = dtype )
480479 cell = keras .layers .LSTMCell (
481- self .units , recurrent_activation = "sigmoid" )
482- cell = wrapper .AttentionWrapper (cell , attention_mechanism )
480+ self .units , recurrent_activation = "sigmoid" , dtype = dtype )
481+ cell = wrapper .AttentionWrapper (cell , attention_mechanism , dtype = dtype )
483482
484483 sampler = sampler_py .TrainingSampler ()
485- my_decoder = basic_decoder .BasicDecoder (cell = cell , sampler = sampler )
484+ my_decoder = basic_decoder .BasicDecoder (
485+ cell = cell , sampler = sampler , dtype = dtype )
486486
487487 final_outputs , final_state , _ = my_decoder (
488488 decoder_inputs ,
@@ -493,8 +493,7 @@ def testBahdanauNormalizedDType(self, dtype):
493493 self .assertEqual (final_outputs .rnn_output .dtype , dtype )
494494 self .assertIsInstance (final_state , wrapper .AttentionWrapperState )
495495
496- # TODO: #407 Float64 test is failing
497- @parameterized .parameters ([np .float32 ])
496+ @parameterized .parameters ([np .float32 , np .float64 ])
498497 def testLuongScaledDType (self , dtype ):
499498 # Test case for GitHub issue 18099
500499 encoder_outputs = self .encoder_outputs .astype (dtype )
@@ -507,11 +506,12 @@ def testLuongScaledDType(self, dtype):
507506 dtype = dtype ,
508507 )
509508 cell = keras .layers .LSTMCell (
510- self .units , recurrent_activation = "sigmoid" )
511- cell = wrapper .AttentionWrapper (cell , attention_mechanism )
509+ self .units , recurrent_activation = "sigmoid" , dtype = dtype )
510+ cell = wrapper .AttentionWrapper (cell , attention_mechanism , dtype = dtype )
512511
513512 sampler = sampler_py .TrainingSampler ()
514- my_decoder = basic_decoder .BasicDecoder (cell = cell , sampler = sampler )
513+ my_decoder = basic_decoder .BasicDecoder (
514+ cell = cell , sampler = sampler , dtype = dtype )
515515
516516 final_outputs , final_state , _ = my_decoder (
517517 decoder_inputs ,
# NOTE(review): the line "0 commit comments" below the final hunk was GitHub page-scrape
# residue, not part of the patch; it has been removed so the diff applies cleanly.