From d1a4d50e588efafdf6119f3b6a87f44a60d2145b Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 2 Jan 2020 14:09:34 -0500 Subject: [PATCH 01/12] extend batch support closes #383 --- test/test_functional.py | 13 +++++++++ test/test_transforms.py | 37 +++++++++++++++++++++++ torchaudio/functional.py | 63 ++++++++++++++++++++++++---------------- torchaudio/transforms.py | 44 +++++++++++++++++++--------- 4 files changed, 119 insertions(+), 38 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 43aa885506..ac02ca6fcd 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -103,6 +103,19 @@ def test_griffinlim(self): self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5)) + # test batch + + # Single then transform then batch + expected = ta_out.unsqueeze(0).repeat(3, 1, 1) + + # Batch then transform + specgram = specgram.unsqueeze(0).repeat(3, 1, 1, 1) + computed = F.griffinlim(specgram, window, n_fft, hop, ws, 1, normalize, + n_iter, momentum, length, rand_init) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) diff --git a/test/test_transforms.py b/test/test_transforms.py index f2b45ec625..fa84f17b2c 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -374,6 +374,19 @@ def test_compute_deltas_twochannel(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def test_batch_MelScale(self): + specgram = torch.randn(2, 31, 2786) + + # Single then transform then batch + expected = transforms.MelScale()(specgram).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.MelScale()(specgram.repeat(3, 1, 1, 1)) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + def test_batch_compute_deltas(self): specgram = torch.randn(2, 31, 2786) @@ -433,6 +446,30 @@ def test_batch_spectrogram(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) + def test_batch_melspectrogram(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) + + # Single then transform then batch + expected = transforms.MelSpectrogram()(waveform).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.MelSpectrogram()(waveform.repeat(3, 1, 1)) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + + def test_batch_mfcc(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) + + # Single then transform then batch + expected = transforms.MFCC()(waveform).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.MFCC()(waveform.repeat(3, 1, 1)) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected, atol=1e-5)) + def test_scriptmodule_TimeStretch(self): n_freq = 400 hop_length = 512 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 9538d70ed5..f826c813d9 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -99,7 +99,7 @@ def istft( Args: stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each - column is a window. it has a size of either (..., fft_size, n_frame, 2) + column is a window. It has a size of either (..., fft_size, n_frame, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. (Default: ``win_length // 4``) @@ -230,7 +230,7 @@ def spectrogram( The spectrogram can be either magnitude-only or complex. Args: - waveform (torch.Tensor): Tensor of audio of dimension (..., channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) pad (int): Two sided padding of signal window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window n_fft (int): Size of FFT @@ -242,8 +242,8 @@ def spectrogram( normalized (bool): Whether to normalize by magnitude after stft Returns: - torch.Tensor: Dimension (..., channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of + torch.Tensor: Dimension (..., freq, time), freq is + ``n_fft // 2 + 1`` and ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). """ @@ -292,7 +292,7 @@ def griffinlim( IEEE Trans. ASSP, vol.32, no.2, pp.236–243, Apr. 1984. Args: - specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (channel, freq, frames) + specgram (torch.Tensor): A magnitude-only STFT spectrogram of dimension (..., freq, frames) where freq is ``n_fft // 2 + 1``. window (torch.Tensor): Window tensor that is applied/multiplied to each frame/window n_fft (int): Size of FFT, creates ``n_fft // 2 + 1`` bins @@ -310,11 +310,15 @@ def griffinlim( rand_init (bool): Initializes phase randomly if True, to zero otherwise. (Default: ``True``) Returns: - torch.Tensor: waveform of (channel, time), where time equals the ``length`` parameter if given. + torch.Tensor: waveform of (..., time), where time equals the ``length`` parameter if given. """ assert momentum < 1, 'momentum=%s > 1 can be unstable' % momentum assert momentum > 0, 'momentum=%s < 0' % momentum + # pack batch + shape = specgram.size() + specgram = specgram.reshape([-1] + list(shape[-2:])) + specgram = specgram.pow(1 / power) # randomly initialize the phase @@ -351,12 +355,17 @@ def griffinlim( angles = angles.div_(complex_norm(angles).add_(1e-16).unsqueeze(-1).expand_as(angles)) # Return the final phase estimates - return istft(specgram * angles, - n_fft=n_fft, - hop_length=hop_length, - win_length=win_length, - window=window, - length=length) + waveform = istft(specgram * angles, + n_fft=n_fft, + hop_length=hop_length, + win_length=win_length, + window=window, + length=length) + + # unpack batch + waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:]) + + return waveform def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None): @@ -699,7 +708,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): https://en.wikipedia.org/wiki/Digital_biquad_filter Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` b0 (float): numerator coefficient of current input, x[n] b1 (float): numerator coefficient of input one time step ago x[n-1] b2 (float): numerator coefficient of input two time steps ago x[n-2] @@ -708,7 +717,7 @@ def biquad(waveform, b0, b1, b2, a0, a1, a2): a2 (float): denominator coefficient of current output y[n-2] Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ device = waveform.device @@ -732,13 +741,13 @@ def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): r"""Design biquad highpass filter and perform filtering. Similar to SoX implementation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) cutoff_freq (float): filter cutoff frequency Q (float): https://en.wikipedia.org/wiki/Q_factor Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ GAIN = 1. @@ -761,13 +770,13 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707): r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) cutoff_freq (float): filter cutoff frequency Q (float): https://en.wikipedia.org/wiki/Q_factor Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ GAIN = 1. @@ -790,14 +799,14 @@ def equalizer_biquad(waveform, sample_rate, center_freq, gain, Q=0.707): r"""Design biquad peaking equalizer filter and perform filtering. Similar to SoX implementation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)` + waveform (torch.Tensor): audio waveform of dimension of `(..., time)` sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz) center_freq (float): filter's central frequency gain (float): desired gain at the boost (or attenuation) in dB q_factor (float): https://en.wikipedia.org/wiki/Q_factor Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)` + output_waveform (torch.Tensor): Dimension of `(..., time)` """ w0 = 2 * math.pi * center_freq / sample_rate A = math.exp(gain / 40.0 * math.log(10)) @@ -856,13 +865,13 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): All examples will have the same mask interval. Args: - specgram (Tensor): Real spectrogram (channel, freq, time) + specgram (Tensor): Real spectrogram (..., freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) Returns: - torch.Tensor: Masked spectrogram of dimensions (channel, freq, time) + torch.Tensor: Masked spectrogram of dimensions (..., freq, time) """ # pack batch @@ -946,7 +955,7 @@ def gain(waveform, gain_db=1.0): r"""Apply amplification or attenuation to the whole waveform. Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time). + waveform (torch.Tensor): Tensor of audio of dimension (..., time). gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`). Returns: @@ -999,7 +1008,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): The relationship of probabilities of results follows a bell-shaped, or Gaussian curve, typical of dither generated by analog sources. Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) probability_density_function (string): The density function of a continuous random variable (Default: `TPDF`) Options: Triangular Probability Density Function - `TPDF` @@ -1008,6 +1017,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): Returns: torch.Tensor: waveform dithered with TPDF """ + + # pack batch shape = waveform.size() waveform = waveform.reshape(-1, shape[-1]) @@ -1047,6 +1058,8 @@ def _apply_probability_distribution(waveform, density_function="TPDF"): quantised_signal_scaled = torch.round(signal_scaled_dis) quantised_signal = quantised_signal_scaled / down_scaling + + # unpack batch return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:]) @@ -1056,7 +1069,7 @@ def dither(waveform, density_function="TPDF", noise_shaping=False): particular bit-depth by eliminating nonlinear truncation distortion (i.e. adding minimally perceived noise to mask distortion caused by quantization). Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) density_function (string): The density function of a continuous random variable (Default: `TPDF`) Options: Triangular Probability Density Function - `TPDF` diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 92b336cf92..fc89f329f5 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -62,11 +62,11 @@ def __init__(self, n_fft=400, win_length=None, hop_length=None, def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) Returns: - torch.Tensor: Dimension (channel, freq, time), where channel - is unchanged, freq is ``n_fft // 2 + 1`` where ``n_fft`` is the number of + torch.Tensor: Dimension (..., freq, time), where freq is + ``n_fft // 2 + 1`` where ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). """ return F.spectrogram(waveform, self.pad, self.window, self.n_fft, self.hop_length, @@ -207,11 +207,16 @@ def __init__(self, n_mels=128, sample_rate=16000, f_min=0., f_max=None, n_stft=N def forward(self, specgram): r""" Args: - specgram (torch.Tensor): A spectrogram STFT of dimension (channel, freq, time) + specgram (torch.Tensor): A spectrogram STFT of dimension (..., freq, time) Returns: - torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) + torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time) """ + + # pack batch + shape = specgram.size() + specgram = specgram.reshape(-1, shape[-2], shape[-1]) + if self.fb.numel() == 0: tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate) # Attributes cannot be reassigned outside __init__ so workaround @@ -221,6 +226,10 @@ def forward(self, specgram): # (channel, frequency, time).transpose(...) dot (frequency, n_mels) # -> (channel, time, n_mels).transpose(...) mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2) + + # unpack batch + mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:]) + return mel_specgram @@ -273,10 +282,10 @@ def __init__(self, sample_rate=16000, n_fft=400, win_length=None, hop_length=Non def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) Returns: - torch.Tensor: Mel frequency spectrogram of size (channel, ``n_mels``, time) + torch.Tensor: Mel frequency spectrogram of size (..., ``n_mels``, time) """ specgram = self.spectrogram(waveform) mel_specgram = self.mel_scale(specgram) @@ -332,11 +341,16 @@ def __init__(self, sample_rate=16000, n_mfcc=40, dct_type=2, norm='ortho', log_m def forward(self, waveform): r""" Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., time) Returns: - torch.Tensor: specgram_mel_db of size (channel, ``n_mfcc``, time) + torch.Tensor: specgram_mel_db of size (..., ``n_mfcc``, time) """ + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + mel_specgram = self.MelSpectrogram(waveform) if self.log_mels: log_offset = 1e-6 @@ -346,6 +360,10 @@ def forward(self, waveform): # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc) # -> (channel, time, n_mfcc).tranpose(...) mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) + + # unpack batch + mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:]) + return mfcc @@ -421,10 +439,10 @@ def __init__(self, orig_freq=16000, new_freq=16000, resampling_method='sinc_inte def forward(self, waveform): r""" Args: - waveform (torch.Tensor): The input signal of dimension (channel, time) + waveform (torch.Tensor): The input signal of dimension (..., time) Returns: - torch.Tensor: Output signal of dimension (channel, time) + torch.Tensor: Output signal of dimension (..., time) """ if self.resampling_method == 'sinc_interpolation': return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq) @@ -471,10 +489,10 @@ def __init__(self, win_length=5, mode="replicate"): def forward(self, specgram): r""" Args: - specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time) Returns: - deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time) """ return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode) From 0dfb41864abd148543e665f9df71991307b50a00 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 2 Jan 2020 14:59:02 -0500 Subject: [PATCH 02/12] todo for a functional to have batch. --- torchaudio/functional.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index f826c813d9..648f9b5585 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -838,6 +838,8 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) """ + # TODO Introduce batch support + if axis != 2 and axis != 3: raise ValueError('Only Frequency and Time masking are supported') From 23d970ef90308431ed0157a27ad5f435053cce17 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 3 Jan 2020 15:06:54 -0500 Subject: [PATCH 03/12] adjust tolerance. --- test/test_functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_functional.py b/test/test_functional.py index ac02ca6fcd..0400c75fab 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -114,7 +114,7 @@ def test_griffinlim(self): n_iter, momentum, length, rand_init) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) - self.assertTrue(torch.allclose(computed, expected)) + self.assertTrue(torch.allclose(computed, expected, atol=5e-5)) def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) From 310a8d2613caa6f280c62db48d765b638d4bf6c3 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 9 Jan 2020 16:39:11 -0500 Subject: [PATCH 04/12] batch for mask --- test/test_functional.py | 20 ++++++++++++++++++++ torchaudio/functional.py | 24 ++++++++++++++++-------- 2 files changed, 36 insertions(+), 8 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 0400c75fab..187c0e82d2 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -509,6 +509,26 @@ def test_pitch(self): assert torch.allclose(freq, freq2, atol=1e-5) + def test_batch_mask_along_axis_iid(self): + + specgram = torch.randn(2, 5, 5) + mask_param = 2 + mask_value = 30. + axis = 2 + + torch.manual_seed(42) + + # Single then transform then batch + expected = F.mask_along_axis_iid(specgram, mask_param=mask_param, mask_value=mask_value, axis=axis) + expected = expected.unsqueeze(0).unsqueeze(0) + + # Batch then transform + specgrams = specgram.unsqueeze(0).unsqueeze(0) + computed = F.mask_along_axis_iid(specgrams, mask_param=mask_param, mask_value=mask_value, axis=axis) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + def _test_batch(self, functional): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 648f9b5585..2aa797874d 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -835,16 +835,21 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) Returns: - torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) + torch.Tensor: Masked spectrograms of dimensions (..., channel, freq, time) """ - # TODO Introduce batch support + # pack batch + shape = specgrams.size() + specgrams = specgrams.reshape([-1] + list(shape[-3:])) if axis != 2 and axis != 3: raise ValueError('Only Frequency and Time masking are supported') - value = torch.rand(specgrams.shape[:2]) * mask_param - min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value) + # Shift so as to start from the end + axis -= 4 + + value = torch.rand(specgrams.shape[:-2]) * mask_param + min_value = torch.rand(specgrams.shape[:-2]) * (specgrams.size(axis) - value) # Create broadcastable mask mask_start = (min_value.long())[..., None, None].float() @@ -856,6 +861,9 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value) specgrams = specgrams.transpose(axis, -1) + # unpack batch + specgrams = specgrams.reshape(shape[:-3] + specgrams.shape[-3:]) + return specgrams @@ -873,12 +881,12 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) Returns: - torch.Tensor: Masked spectrogram of dimensions (..., freq, time) + torch.Tensor: Masked spectrogram of dimensions (..., channel, freq, time) """ # pack batch shape = specgram.size() - specgram = specgram.reshape([-1] + list(shape[-2:])) + specgram = specgram.reshape([-1] + list(shape[-3:])) value = torch.rand(1) * mask_param min_value = torch.rand(1) * (specgram.size(axis) - value) @@ -895,9 +903,9 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): raise ValueError('Only Frequency and Time masking are supported') # unpack batch - specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:]) + specgram = specgram.reshape(shape[:-3] + specgram.shape[-3:]) - return specgram.reshape(shape[:-2] + specgram.shape[-2:]) + return specgram def compute_deltas(specgram, win_length=5, mode="replicate"): From 5713868d2fc013e91f59da1f2c710f256f1e04d5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 10:22:11 -0500 Subject: [PATCH 05/12] function for batch test. --- test/test_functional.py | 106 +++++++++++++++++++++++----------------- 1 file changed, 61 insertions(+), 45 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 187c0e82d2..5e8d5eaa8f 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -103,18 +103,21 @@ def test_griffinlim(self): self.assertTrue(torch.allclose(ta_out, lr_out, atol=5e-5)) - # test batch + def test_batch_griffinlim(self): - # Single then transform then batch - expected = ta_out.unsqueeze(0).repeat(3, 1, 1) + tensor = torch.rand((1, 201, 6)) - # Batch then transform - specgram = specgram.unsqueeze(0).repeat(3, 1, 1, 1) - computed = F.griffinlim(specgram, window, n_fft, hop, ws, 1, normalize, - n_iter, momentum, length, rand_init) + n_fft = 400 + ws = 400 + hop = 200 + window = torch.hann_window(ws) + power = 2 + normalize = False + momentum = 0.99 + n_iter = 32 + length = 1000 - self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) - self.assertTrue(torch.allclose(computed, expected, atol=5e-5)) + self._test_batch(F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0) def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) @@ -139,22 +142,17 @@ def test_compute_deltas_randn(self): win_length = 2 * 7 + 1 specgram = torch.randn(channel, n_mfcc, time) computed = F.compute_deltas(specgram, win_length=win_length) + self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + _test_torchscript_functional(F.compute_deltas, specgram, win_length=win_length) def test_batch_pitch(self): waveform, sample_rate = torchaudio.load(self.test_filepath) + self._test_batch(F.detect_pitch_frequency, waveform, sample_rate) - # Single then transform then batch - expected = F.detect_pitch_frequency(waveform, sample_rate) - expected = expected.unsqueeze(0).repeat(3, 1, 1) - - # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3, 1, 1) - computed = F.detect_pitch_frequency(waveform, sample_rate) - - self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) - self.assertTrue(torch.allclose(computed, expected)) + def test_jit_pitch(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) _test_torchscript_functional(F.detect_pitch_frequency, waveform, sample_rate) def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): @@ -170,7 +168,6 @@ def _test_istft_is_inverse_of_stft(self, kwargs): for data_size in self.data_sizes: for i in range(self.number_of_trials): - # Non-batch sound = common_utils.random_float_tensor(i, data_size) stft = torch.stft(sound, **kwargs) @@ -178,14 +175,6 @@ def _test_istft_is_inverse_of_stft(self, kwargs): self._compare_estimate(sound, estimate) - # Batch - stft = torch.stft(sound, **kwargs) - stft = stft.repeat(3, 1, 1, 1, 1) - sound = sound.repeat(3, 1, 1) - - estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) - self._compare_estimate(sound, estimate) - def test_istft_is_inverse_of_stft1(self): # hann_window, centered, normalized, onesided kwargs1 = { @@ -402,6 +391,16 @@ def test_linearity_of_istft4(self): data_size = (2, 7, 3, 2) self._test_linearity_of_istft(data_size, kwargs4, atol=1e-5, rtol=1e-8) + def test_batch_istft(self): + + stft = torch.tensor([ + [[4., 0.], [4., 0.], [4., 0.], [4., 0.], [4., 0.]], + [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]], + [[0., 0.], [0., 0.], [0., 0.], [0., 0.], [0., 0.]] + ]) + + self._test_batch(F.istft, stft, n_fft=4, length=4) + def _test_create_fb(self, n_mels=40, sample_rate=22050, n_fft=2048, fmin=0.0, fmax=8000.0): # Using a decorator here causes parametrize to fail on Python 2 if not IMPORT_LIBROSA: @@ -502,32 +501,49 @@ def test_pitch(self): self.assertFalse(s) # Convert to stereo and batch for testing purposes - freq = freq.repeat(3, 2, 1, 1) - waveform = waveform.repeat(3, 2, 1, 1) + self._test_batch(F.detect_pitch_frequency, waveform, sample_rate) # , atol=1e-5) + + def _test_batch_shape(self, functional, tensor, *args, **kwargs): + + # Single then transform then batch + + expected = functional(tensor, *args, **kwargs) + expected = expected.unsqueeze(0).unsqueeze(0) + + # 1-Batch then transform + + tensors = tensor.unsqueeze(0).unsqueeze(0) + computed = functional(tensors, *args, **kwargs) + + self._compare_estimate(computed, expected) - freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) + return tensors, expected - assert torch.allclose(freq, freq2, atol=1e-5) + def _test_batch(self, functional, tensor, *args, **kwargs): + + tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs) + + # 3-Batch then transform + + ind = [3] + [1] * (int(tensors.dim()) - 1) + tensors = tensor.repeat(*ind) + + ind = [3] + [1] * (int(expected.dim()) - 1) + expected = expected.repeat(*ind) + + computed = functional(tensors, *args, **kwargs) + + self._compare_estimate(computed, expected) def test_batch_mask_along_axis_iid(self): - specgram = torch.randn(2, 5, 5) + tensor = torch.rand(2, 5, 5) + mask_param = 2 mask_value = 30. axis = 2 - torch.manual_seed(42) - - # Single then transform then batch - expected = F.mask_along_axis_iid(specgram, mask_param=mask_param, mask_value=mask_value, axis=axis) - expected = expected.unsqueeze(0).unsqueeze(0) - - # Batch then transform - specgrams = specgram.unsqueeze(0).unsqueeze(0) - computed = F.mask_along_axis_iid(specgrams, mask_param=mask_param, mask_value=mask_value, axis=axis) - - self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) - self.assertTrue(torch.allclose(computed, expected)) + self._test_batch_shape(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis) def _test_batch(self, functional): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 From 62cc467d8ee48294d95c50f6f6017f9634d0c3cb Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 10:33:18 -0500 Subject: [PATCH 06/12] clean after rebasing. --- test/test_functional.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 5e8d5eaa8f..2c04eeff83 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -501,7 +501,7 @@ def test_pitch(self): self.assertFalse(s) # Convert to stereo and batch for testing purposes - self._test_batch(F.detect_pitch_frequency, waveform, sample_rate) # , atol=1e-5) + self._test_batch(F.detect_pitch_frequency, waveform, sample_rate) def _test_batch_shape(self, functional, tensor, *args, **kwargs): @@ -545,16 +545,6 @@ def test_batch_mask_along_axis_iid(self): self._test_batch_shape(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis) - def _test_batch(self, functional): - waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 - - # Single then transform then batch - expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1) - - # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3, 1, 1) - computed = functional(waveform) - def test_torchscript_create_fb_matrix(self): n_stft = 100 From 42c4fb2e360abe16bcea570f8b3339a2f6040d8f Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 14:40:32 -0500 Subject: [PATCH 07/12] set seed. --- test/test_functional.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_functional.py b/test/test_functional.py index 2c04eeff83..989f98f4c2 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -105,6 +105,7 @@ def test_griffinlim(self): def test_batch_griffinlim(self): + torch.random.manual_seed(42) tensor = torch.rand((1, 201, 6)) n_fft = 400 From a45e619ad76c98dc1b55ebf52869356be81c2e30 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 16:24:24 -0500 Subject: [PATCH 08/12] adjust tolerance for griffinlim. --- test/test_functional.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 989f98f4c2..057dd2a132 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -118,7 +118,7 @@ def test_batch_griffinlim(self): n_iter = 32 length = 1000 - self._test_batch(F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0) + self._test_batch(F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5) def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) @@ -506,6 +506,18 @@ def test_pitch(self): def _test_batch_shape(self, functional, tensor, *args, **kwargs): + kwargs_compare = {} + if 'atol' in kwargs: + atol = kwargs['atol'] + del kwargs['atol'] + kwargs_compare['atol'] = atol + print(kwargs) + + if 'rtol' in kwargs: + rtol = kwargs['rtol'] + del kwargs['rtol'] + kwargs_compare['rtol'] = rtol + # Single then transform then batch expected = functional(tensor, *args, **kwargs) @@ -516,7 +528,7 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs): tensors = tensor.unsqueeze(0).unsqueeze(0) computed = functional(tensors, *args, **kwargs) - self._compare_estimate(computed, expected) + self._compare_estimate(computed, expected, **kwargs_compare) return tensors, expected @@ -524,6 +536,17 @@ def _test_batch(self, functional, tensor, *args, **kwargs): tensors, expected = self._test_batch_shape(functional, tensor, *args, **kwargs) + kwargs_compare = {} + if 'atol' in kwargs: + atol = kwargs['atol'] + del kwargs['atol'] + kwargs_compare['atol'] = atol + + if 'rtol' in kwargs: + rtol = kwargs['rtol'] + del kwargs['rtol'] + kwargs_compare['rtol'] = rtol + # 3-Batch then transform ind = [3] + [1] * (int(tensors.dim()) - 1) @@ -534,7 +557,7 @@ def _test_batch(self, functional, tensor, *args, **kwargs): computed = functional(tensors, *args, **kwargs) - self._compare_estimate(computed, expected) + self._compare_estimate(computed, expected, **kwargs_compare) def test_batch_mask_along_axis_iid(self): From 66f0023f1432f1934b70afbb4b92b2f06e55724f Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 17:03:43 -0500 Subject: [PATCH 09/12] attempt at batch for mask. --- test/test_functional.py | 23 ++++++++++++++++++----- torchaudio/functional.py | 2 +- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 057dd2a132..abb022f04c 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -511,7 +511,6 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs): atol = kwargs['atol'] del kwargs['atol'] kwargs_compare['atol'] = atol - print(kwargs) if 'rtol' in kwargs: rtol = kwargs['rtol'] @@ -520,13 +519,16 @@ def _test_batch_shape(self, functional, tensor, *args, **kwargs): # Single then transform then batch - expected = functional(tensor, *args, **kwargs) + torch.random.manual_seed(42) + expected = functional(tensor.clone(), *args, **kwargs) expected = expected.unsqueeze(0).unsqueeze(0) # 1-Batch then transform tensors = tensor.unsqueeze(0).unsqueeze(0) - computed = functional(tensors, *args, **kwargs) + + torch.random.manual_seed(42) + computed = functional(tensors.clone(), *args, **kwargs) self._compare_estimate(computed, expected, **kwargs_compare) @@ -555,19 +557,30 @@ def _test_batch(self, functional, tensor, *args, **kwargs): ind = [3] + [1] * (int(expected.dim()) - 1) expected = expected.repeat(*ind) - computed = functional(tensors, *args, **kwargs) + torch.random.manual_seed(42) + computed = functional(tensors.clone(), *args, **kwargs) self._compare_estimate(computed, expected, **kwargs_compare) def test_batch_mask_along_axis_iid(self): + mask_param = 2 + mask_value = 30. + axis = 2 + + tensor = torch.rand(2, 5, 5) + + self._test_batch(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis, atol=1e-1, rtol=1e-1) + + def test_batch_mask_along_axis(self): + tensor = torch.rand(2, 5, 5) mask_param = 2 mask_value = 30. axis = 2 - self._test_batch_shape(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis) + self._test_batch(F.mask_along_axis, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis) def test_torchscript_create_fb_matrix(self): diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 2aa797874d..9037bb7b53 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -875,7 +875,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): All examples will have the same mask interval. Args: - specgram (Tensor): Real spectrogram (..., freq, time) + specgram (Tensor): Real spectrogram (..., channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) From bcdefb7f4017ba17b83e3c98f759da955232405f Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 17:07:38 -0500 Subject: [PATCH 10/12] remove batch for mask. --- test/test_functional.py | 22 ---------------------- torchaudio/functional.py | 27 +++++---------------------- 2 files changed, 5 insertions(+), 44 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index abb022f04c..0117ed9de4 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -560,28 +560,6 @@ def _test_batch(self, functional, tensor, *args, **kwargs): torch.random.manual_seed(42) computed = functional(tensors.clone(), *args, **kwargs) - self._compare_estimate(computed, expected, **kwargs_compare) - - def test_batch_mask_along_axis_iid(self): - - mask_param = 2 - mask_value = 30. - axis = 2 - - tensor = torch.rand(2, 5, 5) - - self._test_batch(F.mask_along_axis_iid, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis, atol=1e-1, rtol=1e-1) - - def test_batch_mask_along_axis(self): - - tensor = torch.rand(2, 5, 5) - - mask_param = 2 - mask_value = 30. - axis = 2 - - self._test_batch(F.mask_along_axis, tensor, mask_param=mask_param, mask_value=mask_value, axis=axis) - def test_torchscript_create_fb_matrix(self): n_stft = 100 diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 9037bb7b53..b727062e1b 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -835,21 +835,14 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): axis (int): Axis to apply masking on (2 -> frequency, 3 -> time) Returns: - torch.Tensor: Masked spectrograms of dimensions (..., channel, freq, time) + torch.Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) """ - # pack batch - shape = specgrams.size() - specgrams = specgrams.reshape([-1] + list(shape[-3:])) - if axis != 2 and axis != 3: raise ValueError('Only Frequency and Time masking are supported') - # Shift so as to start from the end - axis -= 4 - - value = torch.rand(specgrams.shape[:-2]) * mask_param - min_value = torch.rand(specgrams.shape[:-2]) * (specgrams.size(axis) - value) + value = torch.rand(specgrams.shape[:2]) * mask_param + min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value) # Create broadcastable mask mask_start = (min_value.long())[..., None, None].float() @@ -861,9 +854,6 @@ def mask_along_axis_iid(specgrams, mask_param, mask_value, axis): specgrams.masked_fill_((mask >= mask_start) & (mask < mask_end), mask_value) specgrams = specgrams.transpose(axis, -1) - # unpack batch - specgrams = specgrams.reshape(shape[:-3] + specgrams.shape[-3:]) - return specgrams @@ -875,19 +865,15 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): All examples will have the same mask interval. Args: - specgram (Tensor): Real spectrogram (..., channel, freq, time) + specgram (Tensor): Real spectrogram (channel, freq, time) mask_param (int): Number of columns to be masked will be uniformly sampled from [0, mask_param] mask_value (float): Value to assign to the masked columns axis (int): Axis to apply masking on (1 -> frequency, 2 -> time) Returns: - torch.Tensor: Masked spectrogram of dimensions (..., channel, freq, time) + torch.Tensor: Masked spectrogram of dimensions (channel, freq, time) """ - # pack batch - shape = specgram.size() - specgram = specgram.reshape([-1] + list(shape[-3:])) - value = torch.rand(1) * mask_param min_value = torch.rand(1) * (specgram.size(axis) - value) @@ -902,9 +888,6 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): else: raise ValueError('Only Frequency and Time masking are supported') - # unpack batch - specgram = specgram.reshape(shape[:-3] + specgram.shape[-3:]) - return specgram From b3525b6b313a947203102995fe72b21a4bec3027 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Fri, 10 Jan 2020 17:27:14 -0500 Subject: [PATCH 11/12] flake8. --- test/test_functional.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/test_functional.py b/test/test_functional.py index 0117ed9de4..6afc14302b 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -118,7 +118,9 @@ def test_batch_griffinlim(self): n_iter = 32 length = 1000 - self._test_batch(F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5) + self._test_batch( + F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0, atol=5e-5 + ) def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) From f5eba11455f6aea33e71beac50bc35237e74b419 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 13 Jan 2020 14:42:53 -0500 Subject: [PATCH 12/12] revert remove batch here. --- torchaudio/functional.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index b727062e1b..97f56b8bc3 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -874,6 +874,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): torch.Tensor: Masked spectrogram of dimensions (channel, freq, time) """ + # pack batch + shape = specgram.size() + specgram = specgram.reshape([-1] + list(shape[-2:])) + value = torch.rand(1) * mask_param min_value = torch.rand(1) * (specgram.size(axis) - value) @@ -888,6 +892,9 @@ def mask_along_axis(specgram, mask_param, mask_value, axis): else: raise ValueError('Only Frequency and Time masking are supported') + # unpack batch + specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:]) + return specgram