From 65c939a7a190f55e7e9a555cbd50fcc94f2e5544 Mon Sep 17 00:00:00 2001 From: Vincent QB Date: Mon, 13 Jan 2020 15:38:39 -0500 Subject: [PATCH] extend batch support (#391) * extend batch support closes #383 * function for batch test. * set seed. --- test/test_functional.py | 95 ++++++++++++++++++++++++++++------------ test/test_transforms.py | 37 ++++++++++++++++ torchaudio/functional.py | 36 ++++++++------- torchaudio/transforms.py | 44 +++++++++++++------ 4 files changed, 154 insertions(+), 58 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 0867600489..059f9bbf9b 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -75,22 +75,17 @@ def test_compute_deltas_randn(self): win_length = 2 * 7 + 1 specgram = torch.randn(channel, n_mfcc, time) computed = F.compute_deltas(specgram, win_length=win_length) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length) def test_batch_pitch(self): waveform, sample_rate = torchaudio.load(self.test_filepath) + self._test_batch(F.detect_pitch_frequency, waveform, sample_rate) - # Single then transform then batch - expected = F.detect_pitch_frequency(waveform, sample_rate) - expected = expected.unsqueeze(0).repeat(3, 1, 1) - - # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3, 1, 1) - computed = F.detect_pitch_frequency(waveform, sample_rate) - - self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) - self.assertTrue(torch.allclose(computed, expected)) + def test_jit_pitch(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate) def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): @@ -106,7 +101,6 @@ def _test_istft_is_inverse_of_stft(self, kwargs): for data_size in self.data_sizes: for i in range(self.number_of_trials): - # Non-batch sound = common_utils.random_float_tensor(i, data_size) stft = torch.stft(sound, **kwargs) @@ -114,14 +108,6 @@ def _test_istft_is_inverse_of_stft(self, kwargs): self._compare_estimate(sound, estimate) - # Batch - stft = torch.stft(sound, **kwargs) - stft = stft.repeat(3, 1, 1, 1, 1) - sound = sound.repeat(3, 1, 1) - - estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) - self._compare_estimate(sound, estimate) - def test_istft_is_inverse_of_stft1(self): # hann_window, centered, normalized, onesided kwargs1 = { @@ -338,6 +324,16 @@ def test_linearity_of_istft4(self): data_size = (2, 7, 3, 2) self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8) + def test_batch_istft(self): + + stft = torch.tensor([ + [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]], + [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], + [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] + ]) + + self._test_batch(F.istft, stft, n_fft=4, length=4) + def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0): # Using a decorator here causes parametrize to fail on Python 2 if not IMPORT_LIBROSA: @@ -438,22 +434,63 @@ def test_pitch(self): self.assertFalse(s) # Convert to stereo and batch for testing purposes - freq = freq.repeat(3, 2, 1, 1) - waveform = waveform.repeat(3, 2, 1, 1) + self._test_batch(F.detect_pitch_frequency, waveform, sample_rate) - freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) + def _test_batch_shape(self, functional, tensor, *args, **kwargs): - assert torch.allclose(freq, freq2, atol=1e-5) + kwargs_compare = {} + if 'atol' in kwargs: + atol = kwargs['atol'] + del kwargs['atol'] + kwargs_compare['atol'] = atol - def _test_batch(self, functional): - waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + if 'rtol' in kwargs: + rtol = kwargs['rtol'] + del kwargs['rtol'] + kwargs_compare['rtol'] = rtol # Single then transform then batch - expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1) - # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3, 1, 1) - computed = functional(waveform) + torch.random.manual_seed(42) + expected = functional(tensor.clone(), *args, **kwargs) + expected = expected.unsqueeze(0).unsqueeze(0) + + # 1-Batch then transform + + tensors = tensor.unsqueeze(0).unsqueeze(0) + + torch.random.manual_seed(42) + computed = functional(tensors.clone(), *args, **kwargs) + + self._compare_estimate(computed, expected, **kwargs_compare) + + return tensors, expected + + def _test_batch(self, functional, tensor, *args, **kwargs): + + tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs) + + kwargs_compare = {} + if 'atol' in kwargs: + atol = kwargs['atol'] + del kwargs['atol'] + kwargs_compare['atol'] = atol + + if 'rtol' in kwargs: + rtol = kwargs['rtol'] + del kwargs['rtol'] + kwargs_compare['rtol'] = rtol + + # 3-Batch then transform + + ind = [3] + [1] * (int(tensors.dim()) - 1) + tensors = tensor.repeat(*ind) + + ind = [3] + [1] * (int(expected.dim()) - 1) + expected = expected.repeat(*ind) + + torch.random.manual_seed(42) + computed = functional(tensors.clone(), *args, **kwargs) def _num_stft_bins(signal_len, fft_len, hop_length, pad): diff --git a/test/test_transforms.py b/test/test_transforms.py index 4defe29911..a278263c56 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -363,6 +363,19 @@ def test_compute_deltas_twochannel(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def test_batch_MelScale(self): + specgram = torch.randn(2, 31, 2786) + + # Single then transform then batch + expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1)) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + def test_batch_compute_deltas(self): specgram = torch.randn(2, 31, 2786) @@ -422,6 +435,30 @@ def test_batch_spectrogram(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + def test_batch_melspectrogram(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) + + # Single then transform then batch + expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1)) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + + def test_batch_mfcc(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) + + # Single then transform then batch + expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.MFCC()(waveform.repeat(3, 1, 1)) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected, atol=1e-5)) + def test_scriptmodule_TimeStretch(self): n_freq = 400 hop_length = 512 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 94c509dbd7..809be33563 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -96,7 +96,7 @@ def istft( Args: stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each - column is a window. it has a size of either (..., fft_size, n_frame, 2) + column is a window. It has a size of either (..., fft_size, n_frame, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. (Default: ``win_length // 4``) @@ -229,7 +229,7 @@ def spectrogram( The spectrogram can be either magnitude-only or complex. Args: - waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) pad (int): Two sided padding of signal window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window n_fft (int): Size of FFT @@ -241,8 +241,8 @@ def spectrogram( normalized (bool): Whether to normalize by magnitude after stft Returns: - torch.Tensor: Dimension (..., channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of + torch.Tensor: Dimension (..., freq, time), freq is + ``n_fft // 2 + 1`` and ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). """ @@ -613,7 +613,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): https://en.wikipedia.org/wiki/Digital_biquad_filter Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` b0 (float): numerator coefficient of current input, x[n] b1 (float): numerator coefficient of input one time step ago x[n-1] b2 (float): numerator coefficient of input two time steps ago x[n-2] @@ -622,7 +622,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): a2 (float): denominator coefficient of current output y[n-2] Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ device = waveform.device @@ -646,13 +646,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) cutoff_freq (float): filter cutoff frequency Q (float): https://en.wikipedia.org/wiki/Q_factor Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ GAIN = 1. @@ -675,13 +675,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) cutoff_freq (float): filter cutoff frequency Q (float): https://en.wikipedia.org/wiki/Q_factor Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ GAIN = 1. @@ -704,14 +704,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) center_freq (float): filter's central frequency gain (float): desired gain at the boost (or attenuation) in dB q_factor (float): https://en.wikipedia.org/wiki/Q_factor Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ w0 = 2 * math.pi * center_freq / sample_rate A = math.exp(gain / 40.0 * math.log(10)) @@ -800,7 +800,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): # unpack batch specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:]) - return specgram.reshape(shape[:-2] + specgram.shape[-2:]) + return specgram def compute_deltas(specgram, win_length=5, mode="replicate"): @@ -860,7 +860,7 @@ def gain(waveform, gain_db=1.0): r"""Apply amplification or attenuation to the whole waveform. Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time). + waveform (torch.Tensor): Tensor of audio of dimension (..., time). gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`). Returns: @@ -913,7 +913,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): The relationship of probabilities of results follows a bell-shaped, or Gaussian curve, typical of dither generated by analog sources. Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) probability_density_function (string): The density function of a continuous random variable (Default: `TPDF`) Options: Triangular Probability Density Function - `TPDF` @@ -922,6 +922,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): Returns: torch.Tensor: waveform dithered with TPDF """ + + # pack batch shape = waveform.size() waveform = waveform.reshape(-1, shape[-1]) @@ -961,6 +963,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): quantised_signal_scaled = torch.round(signal_scaled_dis) quantised_signal = quantised_signal_scaled / down_scaling + + # unpack batch return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:]) @@ -970,7 +974,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False): particular bit-depth by eliminating nonlinear truncation distortion (i.e. adding minimally perceived noise to mask distortion caused by quantization). Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) density_function (string): The density function of a continuous random variable (Default: `TPDF`) Options: Triangular Probability Density Function - `TPDF` diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index b5cb054688..db9cfdcb2b 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -59,11 +59,11 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None, def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) Returns: - torch.Tensor: Dimension (channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + torch.Tensor: Dimension (..., freq, time), where freq is + ``n_fft // 2 + 1`` where ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). """ return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length, @@ -141,11 +141,16 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N def forward(self, specgram): r""" Args: - specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time) + specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time) Returns: - torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) + torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time) """ + + # pack batch + shape = specgram.size() + specgram = specgram.reshape(-1, shape[-2], shape[-1]) + if self.fb.numel() == 0: tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate) # Attributes cannot be reassigned outside __init__ so workaround @@ -155,6 +160,10 @@ def forward(self, specgram): # (channel, frequency, time).transpose(...) dot (frequency, n_mels) # -> (channel, time, n_mels).transpose(...) mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) + + # unpack batch + mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:]) + return mel_specgram @@ -207,10 +216,10 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) Returns: - torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) + torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time) """ specgram = self.spectrogram(waveform) mel_specgram = self.mel_scale(specgram) @@ -266,11 +275,16 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) Returns: - torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time) + torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time) """ + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + mel_specgram = self.MelSpectrogram(waveform) if self.log_mels: log_offset = 1e-6 @@ -280,6 +294,10 @@ def forward(self, waveform): # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) # -> (channel, time, n_mfcc).tranpose(...) mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:]) + return mfcc @@ -355,10 +373,10 @@ def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_inte def forward(self, waveform): r""" Args: - waveform (torch.Tensor): The input signal of dimension (channel, time) + waveform (torch.Tensor): The input signal of dimension (..., time) Returns: - torch.Tensor: Output signal of dimension (channel, time) + torch.Tensor: Output signal of dimension (..., time) """ if self.resampling_method == 'sinc_interpolation': return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq) @@ -405,10 +423,10 @@ def __init__(self, win_length=5, mode="replicate"): def forward(self, specgram): r""" Args: - specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time) Returns: - deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time) """ return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)