@@ -118,7 +118,7 @@ def test_batch_griffinlim(self):
118118 n_iter = 32
119119 length = 1000
120120
121- self ._test_batch (F .griffinlim , tensor , window , n_fft , hop , ws , power , normalize , n_iter , momentum , length , 0 )
121+ self ._test_batch (F .griffinlim , tensor , window , n_fft , hop , ws , power , normalize , n_iter , momentum , length , 0 , atol = 5e-5 )
122122
123123 def _test_compute_deltas (self , specgram , expected , win_length = 3 , atol = 1e-6 , rtol = 1e-8 ):
124124 computed = F .compute_deltas (specgram , win_length = win_length )
@@ -506,6 +506,18 @@ def test_pitch(self):
506506
507507 def _test_batch_shape (self , functional , tensor , * args , ** kwargs ):
508508
509+ kwargs_compare = {}
510+ if 'atol' in kwargs :
511+ atol = kwargs ['atol' ]
512+ del kwargs ['atol' ]
513+ kwargs_compare ['atol' ] = atol
514+ print (kwargs )
515+
516+ if 'rtol' in kwargs :
517+ rtol = kwargs ['rtol' ]
518+ del kwargs ['rtol' ]
519+ kwargs_compare ['rtol' ] = rtol
520+
509521 # Single then transform then batch
510522
511523 expected = functional (tensor , * args , ** kwargs )
@@ -516,14 +528,25 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs):
516528 tensors = tensor .unsqueeze (0 ).unsqueeze (0 )
517529 computed = functional (tensors , * args , ** kwargs )
518530
519- self ._compare_estimate (computed , expected )
531+ self ._compare_estimate (computed , expected , ** kwargs_compare )
520532
521533 return tensors , expected
522534
523535 def _test_batch (self , functional , tensor , * args , ** kwargs ):
524536
525537 tensors , expected = self ._test_batch_shape (functional , tensor , * args , ** kwargs )
526538
539+ kwargs_compare = {}
540+ if 'atol' in kwargs :
541+ atol = kwargs ['atol' ]
542+ del kwargs ['atol' ]
543+ kwargs_compare ['atol' ] = atol
544+
545+ if 'rtol' in kwargs :
546+ rtol = kwargs ['rtol' ]
547+ del kwargs ['rtol' ]
548+ kwargs_compare ['rtol' ] = rtol
549+
527550 # 3-Batch then transform
528551
529552 ind = [3 ] + [1 ] * (int (tensors .dim ()) - 1 )
@@ -534,7 +557,7 @@ def _test_batch(self, functional, tensor, *args, **kwargs):
534557
535558 computed = functional (tensors , * args , ** kwargs )
536559
537- self ._compare_estimate (computed , expected )
560+ self ._compare_estimate (computed , expected , ** kwargs_compare )
538561
539562 def test_batch_mask_along_axis_iid (self ):
540563
0 commit comments