Skip to content

Commit 50453b5

Browse files
committed
function for batch test.
1 parent 53269f3 commit 50453b5

File tree

1 file changed

+61
-45
lines changed

1 file changed

+61
-45
lines changed

test/test_functional.py

Lines changed: 61 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -97,18 +97,21 @@ def test_griffinlim(self):
9797

9898
self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))
9999

100-
# test batch
100+
def test_batch_griffinlim(self):
101101

102-
# Single then transform then batch
103-
expected = ta_out.unsqueeze(0).repeat(3, 1, 1)
102+
tensor = torch.rand((1, 201, 6))
104103

105-
# Batch then transform
106-
specgram = specgram.unsqueeze(0).repeat(3, 1, 1, 1)
107-
computed = F.griffinlim(specgram, window, n_fft, hop, ws, 1, normalize,
108-
n_iter, momentum, length, rand_init)
104+
n_fft = 400
105+
ws = 400
106+
hop = 200
107+
window = torch.hann_window(ws)
108+
power = 2
109+
normalize = False
110+
momentum = 0.99
111+
n_iter = 32
112+
length = 1000
109113

110-
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
111-
self.assertTrue(torch.allclose(computed, expected, atol=5e-5))
114+
self._test_batch(F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0)
112115

113116
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
114117
computed = F.compute_deltas(specgram, win_length=win_length)
@@ -133,22 +136,17 @@ def test_compute_deltas_randn(self):
133136
win_length = 2 * 7 + 1
134137
specgram = torch.randn(channel, n_mfcc, time)
135138
computed = F.compute_deltas(specgram, win_length=win_length)
139+
136140
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
141+
137142
_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)
138143

139144
def test_batch_pitch(self):
140145
waveform, sample_rate = torchaudio.load(self.test_filepath)
146+
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)
141147

142-
# Single then transform then batch
143-
expected = F.detect_pitch_frequency(waveform, sample_rate)
144-
expected = expected.unsqueeze(0).repeat(3, 1, 1)
145-
146-
# Batch then transform
147-
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
148-
computed = F.detect_pitch_frequency(waveform, sample_rate)
149-
150-
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
151-
self.assertTrue(torch.allclose(computed, expected))
148+
def test_jit_pitch(self):
149+
waveform, sample_rate = torchaudio.load(self.test_filepath)
152150
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)
153151

154152
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
@@ -164,22 +162,13 @@ def _test_istft_is_inverse_of_stft(self, kwargs):
164162
for data_size in self.data_sizes:
165163
for i in range(self.number_of_trials):
166164

167-
# Non-batch
168165
sound = common_utils.random_float_tensor(i, data_size)
169166

170167
stft = torch.stft(sound, **kwargs)
171168
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
172169

173170
self._compare_estimate(sound, estimate)
174171

175-
# Batch
176-
stft = torch.stft(sound, **kwargs)
177-
stft = stft.repeat(3, 1, 1, 1, 1)
178-
sound = sound.repeat(3, 1, 1)
179-
180-
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
181-
self._compare_estimate(sound, estimate)
182-
183172
def test_istft_is_inverse_of_stft1(self):
184173
# hann_window, centered, normalized, onesided
185174
kwargs1 = {
@@ -396,6 +385,16 @@ def test_linearity_of_istft4(self):
396385
data_size = (2, 7, 3, 2)
397386
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)
398387

388+
def test_batch_istft(self):
389+
390+
stft = torch.tensor([
391+
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
392+
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
393+
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
394+
])
395+
396+
self._test_batch(F.istft, stft, n_fft=4, length=4)
397+
399398
def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
400399
# Using a decorator here causes parametrize to fail on Python 2
401400
if not IMPORT_LIBROSA:
@@ -496,32 +495,49 @@ def test_pitch(self):
496495
self.assertFalse(s)
497496

498497
# Convert to stereo and batch for testing purposes
499-
freq = freq.repeat(3, 2, 1, 1)
500-
waveform = waveform.repeat(3, 2, 1, 1)
498+
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate) # , atol=1e-5)
499+
500+
def _test_batch_shape(self, functional, tensor, *args, **kwargs):
501+
502+
# Single then transform then batch
503+
504+
expected = functional(tensor, *args, **kwargs)
505+
expected = expected.unsqueeze(0).unsqueeze(0)
506+
507+
# 1-Batch then transform
508+
509+
tensors = tensor.unsqueeze(0).unsqueeze(0)
510+
computed = functional(tensors, *args, **kwargs)
511+
512+
self._compare_estimate(computed, expected)
501513

502-
freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
514+
return tensors, expected
503515

504-
assert torch.allclose(freq, freq2, atol=1e-5)
516+
def _test_batch(self, functional, tensor, *args, **kwargs):
517+
518+
tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)
519+
520+
# 3-Batch then transform
521+
522+
ind = [3] + [1] * (int(tensors.dim()) - 1)
523+
tensors = tensor.repeat(*ind)
524+
525+
ind = [3] + [1] * (int(expected.dim()) - 1)
526+
expected = expected.repeat(*ind)
527+
528+
computed = functional(tensors, *args, **kwargs)
529+
530+
self._compare_estimate(computed, expected)
505531

506532
def test_batch_mask_along_axis_iid(self):
507533

508-
specgram = torch.randn(2, 5, 5)
534+
tensor = torch.rand(2, 5, 5)
535+
509536
mask_param = 2
510537
mask_value = 30.
511538
axis = 2
512539

513-
torch.manual_seed(42)
514-
515-
# Single then transform then batch
516-
expected = F.mask_along_axis_iid(specgram, mask_param=mask_param, mask_value=mask_value, axis=axis)
517-
expected = expected.unsqueeze(0).unsqueeze(0)
518-
519-
# Batch then transform
520-
specgrams = specgram.unsqueeze(0).unsqueeze(0)
521-
computed = F.mask_along_axis_iid(specgrams, mask_param=mask_param, mask_value=mask_value, axis=axis)
522-
523-
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
524-
self.assertTrue(torch.allclose(computed, expected))
540+
self._test_batch_shape(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis)
525541

526542

527543
def _num_stft_bins(signal_len, fft_len, hop_length, pad):

0 commit comments

Comments
 (0)