Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 85 additions & 29 deletions test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,25 @@ def test_griffinlim(self):

self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5))

def test_batch_griffinlim(self):

torch.random.manual_seed(42)
tensor = torch.rand((1, 201, 6))

n_fft = 400
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these parameters special to the batch test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Using the same as test_torchscript_spectrogram, but we can change them or we could set them in the class directly.

ws = 400
hop = 200
window = torch.hann_window(ws)
power = 2
normalize = False
momentum = 0.99
n_iter = 32
length = 1000

self._test_batch(
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5
)

def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
computed = F.compute_deltas(specgram, win_length=win_length)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
Expand All @@ -126,22 +145,17 @@ def test_compute_deltas_randn(self):
win_length = 2 * 7 + 1
specgram = torch.randn(channel, n_mfcc, time)
computed = F.compute_deltas(specgram, win_length=win_length)

self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

_test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length)

def test_batch_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)

# Single then transform then batch
expected = F.detect_pitch_frequency(waveform, sample_rate)
expected = expected.unsqueeze(0).repeat(3, 1, 1)

# Batch then transform
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = F.detect_pitch_frequency(waveform, sample_rate)

self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))
def test_jit_pitch(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)
_test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate)

def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
Expand All @@ -157,22 +171,13 @@ def _test_istft_is_inverse_of_stft(self, kwargs):
for data_size in self.data_sizes:
for i in range(self.number_of_trials):

# Non-batch
sound = common_utils.random_float_tensor(i, data_size)

stft = torch.stft(sound, **kwargs)
estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)

self._compare_estimate(sound, estimate)

# Batch
stft = torch.stft(sound, **kwargs)
stft = stft.repeat(3, 1, 1, 1, 1)
sound = sound.repeat(3, 1, 1)

estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs)
self._compare_estimate(sound, estimate)

def test_istft_is_inverse_of_stft1(self):
# hann_window, centered, normalized, onesided
kwargs1 = {
Expand Down Expand Up @@ -389,6 +394,16 @@ def test_linearity_of_istft4(self):
data_size = (2, 7, 3, 2)
self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8)

def test_batch_istft(self):

stft = torch.tensor([
[[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]],
[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why repeat all 0s instead of something more interesting (maybe just [0., 4.])

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm reusing the parameters from test_istft_of_ones.

[[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]]
])

self._test_batch(F.istft, stft, n_fft=4, length=4)

def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0):
# Using a decorator here causes parametrize to fail on Python 2
if not IMPORT_LIBROSA:
Expand Down Expand Up @@ -489,22 +504,63 @@ def test_pitch(self):
self.assertFalse(s)

# Convert to stereo and batch for testing purposes
freq = freq.repeat(3, 2, 1, 1)
waveform = waveform.repeat(3, 2, 1, 1)
self._test_batch(F.detect_pitch_frequency, waveform, sample_rate)

freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate)
def _test_batch_shape(self, functional, tensor, *args, **kwargs):

assert torch.allclose(freq, freq2, atol=1e-5)
kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol

def _test_batch(self, functional):
waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100
if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol

# Single then transform then batch
expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1)

# Batch then transform
waveform = waveform.unsqueeze(0).repeat(3, 1, 1)
computed = functional(waveform)
torch.random.manual_seed(42)
expected = functional(tensor.clone(), *args, **kwargs)
expected = expected.unsqueeze(0).unsqueeze(0)

# 1-Batch then transform

tensors = tensor.unsqueeze(0).unsqueeze(0)

torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)

self._compare_estimate(computed, expected, **kwargs_compare)

return tensors, expected

def _test_batch(self, functional, tensor, *args, **kwargs):

tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs)

kwargs_compare = {}
if 'atol' in kwargs:
atol = kwargs['atol']
del kwargs['atol']
kwargs_compare['atol'] = atol

if 'rtol' in kwargs:
rtol = kwargs['rtol']
del kwargs['rtol']
kwargs_compare['rtol'] = rtol

# 3-Batch then transform

ind = [3] + [1] * (int(tensors.dim()) - 1)
tensors = tensor.repeat(*ind)

ind = [3] + [1] * (int(expected.dim()) - 1)
expected = expected.repeat(*ind)

torch.random.manual_seed(42)
computed = functional(tensors.clone(), *args, **kwargs)

def test_torchscript_create_fb_matrix(self):

Expand Down
37 changes: 37 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,19 @@ def test_compute_deltas_twochannel(self):
computed = transform(specgram)
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))

def test_batch_MelScale(self):
specgram = torch.randn(2, 31, 2786)

# Single then transform then batch
expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1)

# Batch then transform
computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1))

# shape = (3, 2, 201, 1394)
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))

def test_batch_compute_deltas(self):
specgram = torch.randn(2, 31, 2786)

Expand Down Expand Up @@ -433,6 +446,30 @@ def test_batch_spectrogram(self):
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))

def test_batch_melspectrogram(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)

# Single then transform then batch
expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1)
Copy link
Contributor

@cpuhrsch cpuhrsch Jan 6, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is repetitive. I'm wondering if there is a way of creating a test generator that explicitly exercises this decorator. It's a function that accepts a function with some args and kwargs and then assumes that the first input is to be batchable (so it applies various reshapes to args[0] etc.)

def _gen_batchable(self, func, *args, **kwargs):
    self.assertEqual(func(*args, **kwargs),
                                func(*(args[0].reshape(3, -1, -1, -1) + args[1:]), **kwargs)

(not sure on the reshape behavior, could also be a lambda).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I followed the format that I did for testing jitability, and simply introduced common functions to test batching. Thoughts?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I'm mostly just trying to iterate on top of that. Might be one of these non-problems I'm trying to fix.


# Batch then transform
computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1))

self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected))

def test_batch_mfcc(self):
waveform, sample_rate = torchaudio.load(self.test_filepath)

# Single then transform then batch
expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1)

# Batch then transform
computed = transforms.MFCC()(waveform.repeat(3, 1, 1))

self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
self.assertTrue(torch.allclose(computed, expected, atol=1e-5))

def test_scriptmodule_TimeStretch(self):
n_freq = 400
hop_length = 512
Expand Down
Loading