Skip to content

Commit a45e619

Browse files
committed
adjust tolerance for griffinlim.
1 parent 42c4fb2 commit a45e619

File tree

1 file changed

+26
-3
lines changed

1 file changed

+26
-3
lines changed

test/test_functional.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def test_batch_griffinlim(self):
118118
n_iter = 32
119119
length = 1000
120120

121-
self._test_batch(F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0)
121+
self._test_batch(F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5)
122122

123123
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
124124
computed = F.compute_deltas(specgram, win_length=win_length)
@@ -506,6 +506,18 @@ def test_pitch(self):
506506

507507
def _test_batch_shape(self, functional, tensor, *args, **kwargs):
508508

509+
kwargs_compare = {}
510+
if 'atol' in kwargs:
511+
atol = kwargs['atol']
512+
del kwargs['atol']
513+
kwargs_compare['atol'] = atol
514+
print(kwargs)
515+
516+
if 'rtol' in kwargs:
517+
rtol = kwargs['rtol']
518+
del kwargs['rtol']
519+
kwargs_compare['rtol'] = rtol
520+
509521
# Single then transform then batch
510522

511523
expected = functional(tensor, *args, **kwargs)
@@ -516,14 +528,25 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs):
516528
tensors = tensor.unsqueeze(0).unsqueeze(0)
517529
computed = functional(tensors, *args, **kwargs)
518530

519-
self._compare_estimate(computed, expected)
531+
self._compare_estimate(computed, expected, **kwargs_compare)
520532

521533
return tensors, expected
522534

523535
def _test_batch(self, functional, tensor, *args, **kwargs):
524536

525537
tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)
526538

539+
kwargs_compare = {}
540+
if 'atol' in kwargs:
541+
atol = kwargs['atol']
542+
del kwargs['atol']
543+
kwargs_compare['atol'] = atol
544+
545+
if 'rtol' in kwargs:
546+
rtol = kwargs['rtol']
547+
del kwargs['rtol']
548+
kwargs_compare['rtol'] = rtol
549+
527550
# 3-Batch then transform
528551

529552
ind = [3] + [1] * (int(tensors.dim()) - 1)
@@ -534,7 +557,7 @@ def _test_batch(self, functional, tensor, *args, **kwargs):
534557

535558
computed = functional(tensors, *args, **kwargs)
536559

537-
self._compare_estimate(computed, expected)
560+
self._compare_estimate(computed, expected, **kwargs_compare)
538561

539562
def test_batch_mask_along_axis_iid(self):
540563

0 commit comments

Comments
 (0)