@@ -97,18 +97,21 @@ def test_griffinlim(self):
9797
9898 self .assertTrue (torch .allclose (ta_out , lr_out , atol = 5e-5 ))
9999
100- # test batch
100+ def test_batch_griffinlim ( self ):
101101
102- # Single then transform then batch
103- expected = ta_out .unsqueeze (0 ).repeat (3 , 1 , 1 )
102+ tensor = torch .rand ((1 , 201 , 6 ))
104103
105- # Batch then transform
106- specgram = specgram .unsqueeze (0 ).repeat (3 , 1 , 1 , 1 )
107- computed = F .griffinlim (specgram , window , n_fft , hop , ws , 1 , normalize ,
108- n_iter , momentum , length , rand_init )
104+ n_fft = 400
105+ ws = 400
106+ hop = 200
107+ window = torch .hann_window (ws )
108+ power = 2
109+ normalize = False
110+ momentum = 0.99
111+ n_iter = 32
112+ length = 1000
109113
110- self .assertTrue (computed .shape == expected .shape , (computed .shape , expected .shape ))
111- self .assertTrue (torch .allclose (computed , expected , atol = 5e-5 ))
114+ self ._test_batch (F .griffinlim , tensor , window , n_fft , hop , ws , power , normalize , n_iter , momentum , length , 0 )
112115
113116 def _test_compute_deltas (self , specgram , expected , win_length = 3 , atol = 1e-6 , rtol = 1e-8 ):
114117 computed = F .compute_deltas (specgram , win_length = win_length )
@@ -133,22 +136,17 @@ def test_compute_deltas_randn(self):
133136 win_length = 2 * 7 + 1
134137 specgram = torch .randn (channel , n_mfcc , time )
135138 computed = F .compute_deltas (specgram , win_length = win_length )
139+
136140 self .assertTrue (computed .shape == specgram .shape , (computed .shape , specgram .shape ))
141+
137142 _test_torchscript_functional (F .compute_deltas , specgram , win_length = win_length )
138143
139144 def test_batch_pitch (self ):
140145 waveform , sample_rate = torchaudio .load (self .test_filepath )
146+ self ._test_batch (F .detect_pitch_frequency , waveform , sample_rate )
141147
142- # Single then transform then batch
143- expected = F .detect_pitch_frequency (waveform , sample_rate )
144- expected = expected .unsqueeze (0 ).repeat (3 , 1 , 1 )
145-
146- # Batch then transform
147- waveform = waveform .unsqueeze (0 ).repeat (3 , 1 , 1 )
148- computed = F .detect_pitch_frequency (waveform , sample_rate )
149-
150- self .assertTrue (computed .shape == expected .shape , (computed .shape , expected .shape ))
151- self .assertTrue (torch .allclose (computed , expected ))
148+ def test_jit_pitch (self ):
149+ waveform , sample_rate = torchaudio .load (self .test_filepath )
152150 _test_torchscript_functional (F .detect_pitch_frequency , waveform , sample_rate )
153151
154152 def _compare_estimate (self , sound , estimate , atol = 1e-6 , rtol = 1e-8 ):
@@ -164,22 +162,13 @@ def _test_istft_is_inverse_of_stft(self, kwargs):
164162 for data_size in self .data_sizes :
165163 for i in range (self .number_of_trials ):
166164
167- # Non-batch
168165 sound = common_utils .random_float_tensor (i , data_size )
169166
170167 stft = torch .stft (sound , ** kwargs )
171168 estimate = torchaudio .functional .istft (stft , length = sound .size (1 ), ** kwargs )
172169
173170 self ._compare_estimate (sound , estimate )
174171
175- # Batch
176- stft = torch .stft (sound , ** kwargs )
177- stft = stft .repeat (3 , 1 , 1 , 1 , 1 )
178- sound = sound .repeat (3 , 1 , 1 )
179-
180- estimate = torchaudio .functional .istft (stft , length = sound .size (1 ), ** kwargs )
181- self ._compare_estimate (sound , estimate )
182-
183172 def test_istft_is_inverse_of_stft1 (self ):
184173 # hann_window, centered, normalized, onesided
185174 kwargs1 = {
@@ -396,6 +385,16 @@ def test_linearity_of_istft4(self):
396385 data_size = (2 , 7 , 3 , 2 )
397386 self ._test_linearity_of_istft (data_size , kwargs4 , atol = 1e-5 , rtol = 1e-8 )
398387
388+ def test_batch_istft (self ):
389+
390+ stft = torch .tensor ([
391+ [[4. , 0. ], [4. , 0. ], [4. , 0. ], [4. , 0. ], [4. , 0. ]],
392+ [[0. , 0. ], [0. , 0. ], [0. , 0. ], [0. , 0. ], [0. , 0. ]],
393+ [[0. , 0. ], [0. , 0. ], [0. , 0. ], [0. , 0. ], [0. , 0. ]]
394+ ])
395+
396+ self ._test_batch (F .istft , stft , n_fft = 4 , length = 4 )
397+
399398 def _test_create_fb (self , n_mels = 40 , sample_rate = 22050 , n_fft = 2048 , fmin = 0.0 , fmax = 8000.0 ):
400399 # Using a decorator here causes parametrize to fail on Python 2
401400 if not IMPORT_LIBROSA :
@@ -496,32 +495,49 @@ def test_pitch(self):
496495 self .assertFalse (s )
497496
498497 # Convert to stereo and batch for testing purposes
499- freq = freq .repeat (3 , 2 , 1 , 1 )
500- waveform = waveform .repeat (3 , 2 , 1 , 1 )
498+ self ._test_batch (F .detect_pitch_frequency , waveform , sample_rate ) # , atol=1e-5)
499+
500+ def _test_batch_shape (self , functional , tensor , * args , ** kwargs ):
501+
502+ # Single then transform then batch
503+
504+ expected = functional (tensor , * args , ** kwargs )
505+ expected = expected .unsqueeze (0 ).unsqueeze (0 )
506+
507+ # 1-Batch then transform
508+
509+ tensors = tensor .unsqueeze (0 ).unsqueeze (0 )
510+ computed = functional (tensors , * args , ** kwargs )
511+
512+ self ._compare_estimate (computed , expected )
501513
502- freq2 = torchaudio . functional . detect_pitch_frequency ( waveform , sample_rate )
514+ return tensors , expected
503515
504- assert torch .allclose (freq , freq2 , atol = 1e-5 )
516+ def _test_batch (self , functional , tensor , * args , ** kwargs ):
517+
518+ tensors , expected = self ._test_batch_shape (functional , tensor , * args , ** kwargs )
519+
520+ # 3-Batch then transform
521+
522+ ind = [3 ] + [1 ] * (int (tensors .dim ()) - 1 )
523+ tensors = tensor .repeat (* ind )
524+
525+ ind = [3 ] + [1 ] * (int (expected .dim ()) - 1 )
526+ expected = expected .repeat (* ind )
527+
528+ computed = functional (tensors , * args , ** kwargs )
529+
530+ self ._compare_estimate (computed , expected )
505531
506532 def test_batch_mask_along_axis_iid (self ):
507533
508- specgram = torch .randn (2 , 5 , 5 )
534+ tensor = torch .rand (2 , 5 , 5 )
535+
509536 mask_param = 2
510537 mask_value = 30.
511538 axis = 2
512539
513- torch .manual_seed (42 )
514-
515- # Single then transform then batch
516- expected = F .mask_along_axis_iid (specgram , mask_param = mask_param , mask_value = mask_value , axis = axis )
517- expected = expected .unsqueeze (0 ).unsqueeze (0 )
518-
519- # Batch then transform
520- specgrams = specgram .unsqueeze (0 ).unsqueeze (0 )
521- computed = F .mask_along_axis_iid (specgrams , mask_param = mask_param , mask_value = mask_value , axis = axis )
522-
523- self .assertTrue (computed .shape == expected .shape , (computed .shape , expected .shape ))
524- self .assertTrue (torch .allclose (computed , expected ))
540+ self ._test_batch_shape (F .mask_along_axis_iid , tensor , mask_param = mask_param , mask_value = mask_value , axis = axis )
525541
526542
527543def _num_stft_bins (signal_len , fft_len , hop_length , pad ):
0 commit comments