Skip to content

Commit 64551a6

Browse files
authored
Apply misc updates to functional/batch_consistency_test.py (#1341)
* Parameterize `test_sliding_window_cmn` * Extract test naming function * Pass a spectrogram to `F.sliding_window_cmn` * Set manual seed for remaining rand calls in suite
1 parent 9a96fb7 commit 64551a6

File tree

1 file changed

+18
-10
lines changed

1 file changed

+18
-10
lines changed

test/torchaudio_unittest/functional/batch_consistency_test.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
from torchaudio_unittest import common_utils
1010

1111

12+
def _name_from_args(func, _, params):
13+
"""Return a parameterized test name, based on parameter values."""
14+
return "{}_{}".format(
15+
func.__name__,
16+
"_".join(str(arg) for arg in params.args))
17+
18+
1219
@parameterized_class([
1320
# Single-item batch isolates problems that come purely from adding a
1421
# dimension (rather than processing multiple items)
@@ -58,7 +65,7 @@ def test_griffinlim(self):
5865
@parameterized.expand(list(itertools.product(
5966
[8000, 16000, 44100],
6067
[1, 2],
61-
)), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}')
68+
)), name_func=_name_from_args)
6269
def test_detect_pitch_frequency(self, sample_rate, n_channels):
6370
# Use different frequencies to ensure each item in the batch returns a
6471
# different answer.
@@ -180,16 +187,16 @@ def test_flanger(self):
180187
sample_rate = 44100
181188
self.assert_batch_consistency(F.flanger, waveforms, sample_rate)
182189

183-
def test_sliding_window_cmn(self):
184-
waveforms = torch.randn(self.batch_size, 2, 1024) - 0.5
185-
self.assert_batch_consistency(
186-
F.sliding_window_cmn, waveforms, center=True, norm_vars=True)
187-
self.assert_batch_consistency(
188-
F.sliding_window_cmn, waveforms, center=True, norm_vars=False)
189-
self.assert_batch_consistency(
190-
F.sliding_window_cmn, waveforms, center=False, norm_vars=True)
190+
@parameterized.expand(list(itertools.product(
191+
[True, False], # center
192+
[True, False], # norm_vars
193+
)), name_func=_name_from_args)
194+
def test_sliding_window_cmn(self, center, norm_vars):
195+
torch.manual_seed(0)
196+
spectrogram = torch.rand(self.batch_size, 2, 1024, 1024) * 200
191197
self.assert_batch_consistency(
192-
F.sliding_window_cmn, waveforms, center=False, norm_vars=False)
198+
F.sliding_window_cmn, spectrogram, center=center,
199+
norm_vars=norm_vars)
193200

194201
def test_vad_from_file(self):
195202
filepath = common_utils.get_asset_path("vad-go-stereo-44100.wav")
@@ -202,6 +209,7 @@ def test_vad_from_file(self):
202209
def test_vad_different_items(self):
203210
"""Separate test to ensure VAD consistency with differing items."""
204211
sample_rate = 44100
212+
torch.manual_seed(0)
205213
waveforms = torch.rand(self.batch_size, 2, 100) - 0.5
206214
self.assert_batch_consistency(
207215
F.vad, waveforms, sample_rate=sample_rate)

0 commit comments

Comments
 (0)