From e7ff702f90ab46ca7f1ede6af8e32abdf31b141b Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 11 Nov 2019 10:39:28 -0800 Subject: [PATCH 01/15] batching for transforms. --- test/test_functional.py | 20 ++++++++++++ test/test_transforms.py | 34 +++++++++++++++++++ torchaudio/functional.py | 70 +++++++++++++++++++++++++++++++--------- 3 files changed, 109 insertions(+), 15 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index be0804b6c6..afd6a31b75 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -21,6 +21,10 @@ class TestFunctional(unittest.TestCase): number_of_trials = 100 specgram = torch.tensor([1., 2., 3., 4.]) + test_dirpath, test_dir = common_utils.create_temp_assets_dir() + test_filepath = os.path.join(test_dirpath, 'assets', + 'steam-train-whistle-daniel_simon.mp3') + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) @@ -46,6 +50,22 @@ def test_compute_deltas_randn(self): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def _test_batch(self, transform): + waveform, sample_rate = torchaudio.load(self.test_filepath) + + # Single then transform then batch + expected = transform(waveform).unsqueeze(0).repeat(3,1,1,1) + + # Batch then transform + waveform = waveform.unsqueeze(0).repeat(3,1,1) + computed = transform(waveform) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + + def test_batch_spectrogram(self): + self._test_batch(F.spectrogram) + def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): # trim sound for case when constructed signal is shorter than original sound = sound[..., :estimate.size(-1)] diff --git a/test/test_transforms.py b/test/test_transforms.py index d78c5b6000..7056e18e2b 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -313,6 +313,40 @@ def test_compute_deltas_twochannel(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def _test_batch(self, Transform): + waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + + # Single then transform then batch + expected = Transform()(waveform).unsqueeze(0).repeat(3, 1, 1, 1) + + # Batch then transform + waveform = waveform.unsqueeze(0).repeat(3, 1, 1) + computed = Transform()(waveform) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + + def test_batch(self): + self._test_batch(Spectrogram) + self._test_batch(MuLawEncoding) + self._test_batch(ComputeDeltas) + + def test_batch_mulawdecoding(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + + # Single then transform then batch + specgram = MuLawEncoding()(specgram) + expected = MuLawEncoding()(specgram).unsqueeze(0).repeat(3, 1, 1, 1, 1) + + # Batch then transform + waveform = waveform.unsqueeze(0).repeat(3, 1, 1, 1) + computed = MuLawEncoding()(waveform) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + if __name__ == '__main__': unittest.main() diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 7539f371d4..c4add4324c 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -115,16 +115,19 @@ def istft( original signal length). (Default: whole signal) Returns: - torch.Tensor: Least squares estimation of the original signal of size - (channel, signal_length) or (signal_length) + torch.Tensor: Least squares estimation of the original signal of size (*, signal_length) """ stft_matrix_dim = stft_matrix.dim() - assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim) + assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) if stft_matrix_dim == 3: # add a channel dimension stft_matrix = stft_matrix.unsqueeze(0) + # pack batch + shape = stft_matrix.size() + stft_matrix = stft_matrix.reshape(-1, *shape[-3:]) + dtype = stft_matrix.dtype device = stft_matrix.device fft_size = stft_matrix.size(1) @@ -211,6 +214,10 @@ def istft( if stft_matrix_dim == 3: # remove the channel dimension y = y.squeeze(0) + + # unpack batch + y = y.reshape(shape[:-2] + y.shape[-3:]) + return y @@ -505,14 +512,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): dtype=complex_specgrams.dtype) alphas = time_steps % 1.0 - phase_0 = angle(complex_specgrams[:, :, :1]) + phase_0 = angle(complex_specgrams[..., :1]) # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) # (new_bins, freq, 2) - complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()] - complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()] + complex_specgrams_0 = complex_specgrams[..., time_steps.long()] + complex_specgrams_1 = complex_specgrams[..., (time_steps + 1).long()] angle_0 = angle(complex_specgrams_0) angle_1 = angle(complex_specgrams_1) @@ -525,7 +532,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): # Compute Phase Accum phase = phase + phase_advance - phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1) + phase = torch.cat([phase_0, phase[..., :-1]], dim=-1) phase_acc = torch.cumsum(phase, -1) mag = alphas * norm_1 + (1 - alphas) * norm_0 @@ -545,7 +552,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): Performs an IIR filter by evaluating difference equation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`. Must be normalized to -1 to 1. + waveform (torch.Tensor): audio waveform of dimension of `(*, channel, time)`. Must be normalized to -1 to 1. a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`. Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`. Must be same size as b_coeffs (pad with 0's as necessary). @@ -554,10 +561,16 @@ def lfilter(waveform, a_coeffs, b_coeffs): Must be same size as a_coeffs (pad with 0's as necessary). Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)`. Output will be clipped to -1 to 1. + output_waveform (torch.Tensor): Dimension of `(*, channel, time)`. Output will be clipped to -1 to 1. """ + dim = waveform.dim() + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + assert(a_coeffs.size(0) == b_coeffs.size(0)) assert(len(waveform.size()) == 2) assert(waveform.device == a_coeffs.device) @@ -597,7 +610,14 @@ def lfilter(waveform, a_coeffs, b_coeffs): padded_output_waveform[:, i_sample + n_order - 1] = o0 - return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])) + output = torch.min( + ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]) + ) + + # unpack batch + output = output.reshape(shape[:-1] + output.shape[-1:]) + + return output @torch.jit.script @@ -808,12 +828,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): :math:`N` is (`win_length`-1)//2. Args: - specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + specgram (torch.Tensor): Tensor of audio of dimension (*, channel, freq, time) win_length (int): The window length used for computing delta mode (str): Mode parameter passed to padding Returns: - deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + deltas (torch.Tensor): Tensor of audio of dimension (*, channel, freq, time) Example >>> specgram = torch.randn(1, 40, 1000) @@ -821,6 +841,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): >>> delta2 = compute_deltas(delta) """ + dim = waveform.dim() + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-2], shape[-1]) + assert win_length >= 3 assert specgram.dim() == 3 assert not specgram.shape[1] % specgram.shape[0] @@ -838,10 +864,15 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): .repeat(specgram.shape[1], specgram.shape[0], 1) ) - return torch.nn.functional.conv1d( + output = torch.nn.functional.conv1d( specgram, kernel, groups=specgram.shape[1] // specgram.shape[0] ) / denom + # unpack batch + output = output.reshape(shape[:-1] + output.shape[-1:]) + + return output + @torch.jit.script def _compute_nccf(waveform, sample_rate, frame_time, freq_low): @@ -973,16 +1004,22 @@ def detect_pitch_frequency( It is implemented using normalized cross-correlation function and median smoothing. Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + waveform (torch.Tensor): Tensor of audio of dimension (*, channel, freq, time) sample_rate (int): The sample rate of the waveform (Hz) win_length (int): The window length for median smoothing (in number of frames) freq_low (int): Lowest frequency that can be detected (Hz) freq_high (int): Highest frequency that can be detected (Hz) Returns: - freq (torch.Tensor): Tensor of audio of dimension (channel, frame) + freq (torch.Tensor): Tensor of audio of dimension (*, channel, frame) """ + dim = waveform.dim() + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-2], shape[-1]) + nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) indices = _find_max_per_frame(nccf, sample_rate, freq_high) indices = _median_smoothing(indices, win_length) @@ -991,4 +1028,7 @@ def detect_pitch_frequency( EPSILON = 10 ** (-9) freq = sample_rate / (EPSILON + indices.to(torch.float)) + # unpack batch + freq = freq.reshape(shape[:-2] + freq.shape[-1:]) + return freq From ace097dcb33a56ef80dcd9a047a0f8278c99ff3d Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 12 Nov 2019 17:49:48 -0800 Subject: [PATCH 02/15] test for batching. --- test/test_functional.py | 41 +++++++++++++++++++++++++++++++--------- torchaudio/functional.py | 26 ++++++++++++------------- 2 files changed, 45 insertions(+), 22 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index afd6a31b75..10b576ad84 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -50,22 +50,20 @@ def test_compute_deltas_randn(self): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) - def _test_batch(self, transform): + def test_batch_pitch(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # Single then transform then batch - expected = transform(waveform).unsqueeze(0).repeat(3,1,1,1) + expected = F.detect_pitch_frequency(waveform, sample_rate) + expected = expected.unsqueeze(0).repeat(3,1,1) # Batch then transform waveform = waveform.unsqueeze(0).repeat(3,1,1) - computed = transform(waveform) + computed = F.detect_pitch_frequency(waveform, sample_rate) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) - def test_batch_spectrogram(self): - self._test_batch(F.spectrogram) - def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): # trim sound for case when constructed signal is shorter than original sound = sound[..., :estimate.size(-1)] @@ -78,6 +76,8 @@ def _test_istft_is_inverse_of_stft(self, kwargs): # operation to check whether we can reconstruct signal for data_size in self.data_sizes: for i in range(self.number_of_trials): + + # Non-batch sound = common_utils.random_float_tensor(i, data_size) stft = torch.stft(sound, **kwargs) @@ -85,6 +85,14 @@ def _test_istft_is_inverse_of_stft(self, kwargs): self._compare_estimate(sound, estimate) + # Batch + stft = torch.stft(sound, **kwargs) + stft = stft.repeat(3, 1, 1, 1, 1) + sound = sound.repeat(3, 1, 1) + + estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) + self._compare_estimate(sound, estimate) + def test_istft_is_inverse_of_stft1(self): # hann_window, centered, normalized, onesided kwargs1 = { @@ -346,15 +354,30 @@ def test_pitch(self): for filename, freq_ref in tests: waveform, sample_rate = torchaudio.load(filename) - # Convert to stereo for testing purposes - waveform = waveform.repeat(2, 1, 1) - freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) threshold = 1 s = ((freq - freq_ref).abs() > threshold).sum() self.assertFalse(s) + # Convert to stereo and batch for testing purposes + freq = freq.repeat(3, 2, 1, 1) + waveform = waveform.repeat(3, 2, 1, 1) + + freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) + + assert torch.allclose(freq, freq2, atol=1e-5) + + def _test_batch(self, functional): + waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + + # Single then transform then batch + expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1) + + # Batch then transform + waveform = waveform.unsqueeze(0).repeat(3, 1, 1) + computed = functional(waveform) + def _num_stft_bins(signal_len, fft_len, hop_length, pad): return (signal_len + 2 * pad - fft_len + hop_length) // hop_length diff --git a/torchaudio/functional.py b/torchaudio/functional.py index c4add4324c..c9501cf832 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -96,8 +96,7 @@ def istft( Args: stft_matrix (torch.Tensor): Output of stft where each row of a channel is a frequency and each - column is a window. it has a size of either (channel, fft_size, n_frame, 2) or ( - fft_size, n_frame, 2) + column is a window. it has a size of either (..., fft_size, n_frame, 2) n_fft (int): Size of Fourier transform hop_length (Optional[int]): The distance between neighboring sliding window frames. (Default: ``win_length // 4``) @@ -115,10 +114,11 @@ def istft( original signal length). (Default: whole signal) Returns: - torch.Tensor: Least squares estimation of the original signal of size (*, signal_length) + torch.Tensor: Least squares estimation of the original signal of size (..., signal_length) """ stft_matrix_dim = stft_matrix.dim() assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) + assert stft_matrix.nelement() > 0 if stft_matrix_dim == 3: # add a channel dimension @@ -212,12 +212,12 @@ def istft( y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) + # unpack batch + y = y.reshape(shape[:-3] + y.shape[-1:]) + if stft_matrix_dim == 3: # remove the channel dimension y = y.squeeze(0) - # unpack batch - y = y.reshape(shape[:-2] + y.shape[-3:]) - return y @@ -518,8 +518,8 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) # (new_bins, freq, 2) - complex_specgrams_0 = complex_specgrams[..., time_steps.long()] - complex_specgrams_1 = complex_specgrams[..., (time_steps + 1).long()] + complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long()) + complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long()) angle_0 = angle(complex_specgrams_0) angle_1 = angle(complex_specgrams_1) @@ -841,11 +841,11 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): >>> delta2 = compute_deltas(delta) """ - dim = waveform.dim() + dim = specgram.dim() # pack batch - shape = waveform.size() - waveform = waveform.reshape(-1, shape[-2], shape[-1]) + shape = specgram.size() + waveform = specgram.reshape(-1, shape[-2], shape[-1]) assert win_length >= 3 assert specgram.dim() == 3 @@ -1018,7 +1018,7 @@ def detect_pitch_frequency( # pack batch shape = waveform.size() - waveform = waveform.reshape(-1, shape[-2], shape[-1]) + waveform = waveform.reshape([-1] + shape[-1:]) nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) indices = _find_max_per_frame(nccf, sample_rate, freq_high) @@ -1029,6 +1029,6 @@ def detect_pitch_frequency( freq = sample_rate / (EPSILON + indices.to(torch.float)) # unpack batch - freq = freq.reshape(shape[:-2] + freq.shape[-1:]) + freq = freq.reshape(shape[:-1] + freq.shape[-1:]) return freq From 97a22508da1dade943d1273c80da55da3cfd5b59 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Tue, 12 Nov 2019 17:51:19 -0800 Subject: [PATCH 03/15] upate * notation to ... --- torchaudio/functional.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index c9501cf832..453fe3e625 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -438,11 +438,11 @@ def complex_norm(complex_tensor, power=1.0): r"""Compute the norm of complex tensor input. Args: - complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` power (float): Power of the norm. (Default: `1.0`). Returns: - torch.Tensor: Power of the normed input tensor. Shape of `(*, )` + torch.Tensor: Power of the normed input tensor. Shape of `(..., )` """ if power == 1.0: return torch.norm(complex_tensor, 2, -1) @@ -455,10 +455,10 @@ def angle(complex_tensor): r"""Compute the angle of complex tensor input. Args: - complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` Return: - torch.Tensor: Angle of a complex tensor. Shape of `(*, )` + torch.Tensor: Angle of a complex tensor. Shape of `(..., )` """ return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0]) @@ -466,10 +466,10 @@ def angle(complex_tensor): @torch.jit.script def magphase(complex_tensor, power=1.0): # type: (Tensor, float) -> Tuple[Tensor, Tensor] - r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase. + r"""Separate a complex-valued spectrogram with shape `(..., 2)` into its magnitude and phase. Args: - complex_tensor (torch.Tensor): Tensor shape of `(*, complex=2)` + complex_tensor (torch.Tensor): Tensor shape of `(..., complex=2)` power (float): Power of the norm. (Default: `1.0`) Returns: @@ -552,7 +552,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): Performs an IIR filter by evaluating difference equation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(*, channel, time)`. Must be normalized to -1 to 1. + waveform (torch.Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1. a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`. Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`. Must be same size as b_coeffs (pad with 0's as necessary). @@ -561,7 +561,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): Must be same size as a_coeffs (pad with 0's as necessary). Returns: - output_waveform (torch.Tensor): Dimension of `(*, channel, time)`. Output will be clipped to -1 to 1. + output_waveform (torch.Tensor): Dimension of `(..., time)`. Output will be clipped to -1 to 1. """ @@ -828,12 +828,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): :math:`N` is (`win_length`-1)//2. Args: - specgram (torch.Tensor): Tensor of audio of dimension (*, channel, freq, time) + specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time) win_length (int): The window length used for computing delta mode (str): Mode parameter passed to padding Returns: - deltas (torch.Tensor): Tensor of audio of dimension (*, channel, freq, time) + deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time) Example >>> specgram = torch.randn(1, 40, 1000) @@ -1004,14 +1004,14 @@ def detect_pitch_frequency( It is implemented using normalized cross-correlation function and median smoothing. Args: - waveform (torch.Tensor): Tensor of audio of dimension (*, channel, freq, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., freq, time) sample_rate (int): The sample rate of the waveform (Hz) win_length (int): The window length for median smoothing (in number of frames) freq_low (int): Lowest frequency that can be detected (Hz) freq_high (int): Highest frequency that can be detected (Hz) Returns: - freq (torch.Tensor): Tensor of audio of dimension (*, channel, frame) + freq (torch.Tensor): Tensor of audio of dimension (..., frame) """ dim = waveform.dim() From 35a58b37be6b4a5ec4f69559b059f120125739e7 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 13 Nov 2019 12:31:14 -0800 Subject: [PATCH 04/15] correct shape sent to angle. --- torchaudio/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 453fe3e625..541078ce92 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -512,7 +512,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): dtype=complex_specgrams.dtype) alphas = time_steps % 1.0 - phase_0 = angle(complex_specgrams[..., :1]) + phase_0 = angle(complex_specgrams) # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) From 331d6fc19956cb270b16bad8265ff9b11833b1dc Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 13 Nov 2019 15:13:15 -0800 Subject: [PATCH 05/15] correct indexing. --- torchaudio/functional.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 541078ce92..1073da631d 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -512,14 +512,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): dtype=complex_specgrams.dtype) alphas = time_steps % 1.0 - phase_0 = angle(complex_specgrams) + phase_0 = angle(complex_specgrams[..., :1, :]) # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) # (new_bins, freq, 2) - complex_specgrams_0 = complex_specgrams.index_select(-1, time_steps.long()) - complex_specgrams_1 = complex_specgrams.index_select(-1, (time_steps + 1).long()) + complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long()) + complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long()) angle_0 = angle(complex_specgrams_0) angle_1 = angle(complex_specgrams_1) From f66a717f97c823f5920a8a4c5067ef7f557626e4 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 13 Nov 2019 15:28:36 -0800 Subject: [PATCH 06/15] upadte test. --- test/test_transforms.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 7056e18e2b..016741b822 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -328,20 +328,19 @@ def _test_batch(self, Transform): self.assertTrue(torch.allclose(computed, expected)) def test_batch(self): - self._test_batch(Spectrogram) - self._test_batch(MuLawEncoding) - self._test_batch(ComputeDeltas) + self._test_batch(transforms.Spectrogram) + self._test_batch(transforms.ComputeDeltas) def test_batch_mulawdecoding(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 # Single then transform then batch - specgram = MuLawEncoding()(specgram) - expected = MuLawEncoding()(specgram).unsqueeze(0).repeat(3, 1, 1, 1, 1) + specgram = transforms.MuLawEncoding()(specgram) + expected = transforms.MuLawEncoding()(specgram).unsqueeze(0).repeat(3, 1, 1, 1, 1) # Batch then transform waveform = waveform.unsqueeze(0).repeat(3, 1, 1, 1) - computed = MuLawEncoding()(waveform) + computed = transforms.MuLawEncoding()(waveform) # shape = (3, 2, 201, 1394) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) From 33590548fcffffb373f5b02946553a6715e79b83 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 13 Nov 2019 17:25:08 -0800 Subject: [PATCH 07/15] correct variable. --- test/test_transforms.py | 43 ++++++++++++++++++++++++++++++---------- torchaudio/functional.py | 2 +- 2 files changed, 33 insertions(+), 12 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 016741b822..a8f2008cfd 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -313,34 +313,55 @@ def test_compute_deltas_twochannel(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) - def _test_batch(self, Transform): + def test_batch_compute_deltas(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 # Single then transform then batch - expected = Transform()(waveform).unsqueeze(0).repeat(3, 1, 1, 1) + expected = transforms.ComputeDeltas()(waveform).unsqueeze(0).repeat(3, 1, 1, 1) # Batch then transform waveform = waveform.unsqueeze(0).repeat(3, 1, 1) - computed = Transform()(waveform) + computed = transforms.ComputeDeltas()(waveform) # shape = (3, 2, 201, 1394) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) - def test_batch(self): - self._test_batch(transforms.Spectrogram) - self._test_batch(transforms.ComputeDeltas) + def test_batch_spectrogram(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + + # Single then transform then batch + expected = transforms.Spectrogram()(waveform).unsqueeze(0).repeat(3, 1, 1, 1) + + # Batch then transform + waveform = waveform.unsqueeze(0).repeat(3, 1, 1) + computed = transforms.Spectrogram()(waveform) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) - def test_batch_mulawdecoding(self): + def test_batch_mulaw(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 # Single then transform then batch - specgram = transforms.MuLawEncoding()(specgram) - expected = transforms.MuLawEncoding()(specgram).unsqueeze(0).repeat(3, 1, 1, 1, 1) + waveform_encoded = transforms.MuLawEncoding()(waveform) + expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1) + + # Batch then transform + waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1) + computed = transforms.MuLawEncoding()(waveform_batched) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + + # Single then transform then batch + waveform_decoded = transforms.MuLawDecoding()(waveform_encoded) + expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1) # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3, 1, 1, 1) - computed = transforms.MuLawEncoding()(waveform) + computed = transforms.MuLawDecoding()(computed) # shape = (3, 2, 201, 1394) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 1073da631d..2ab2a6f643 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -845,7 +845,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): # pack batch shape = specgram.size() - waveform = specgram.reshape(-1, shape[-2], shape[-1]) + specgram = specgram.reshape(-1, shape[-2], shape[-1]) assert win_length >= 3 assert specgram.dim() == 3 From d92eef66bbf872b0d54fb0adb87240f1978fe2b5 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Wed, 13 Nov 2019 17:39:27 -0800 Subject: [PATCH 08/15] spectrogram test belongs to separate PR. --- test/test_transforms.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index a8f2008cfd..afeb22541e 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -327,20 +327,6 @@ def test_batch_compute_deltas(self): self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) self.assertTrue(torch.allclose(computed, expected)) - def test_batch_spectrogram(self): - waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 - - # Single then transform then batch - expected = transforms.Spectrogram()(waveform).unsqueeze(0).repeat(3, 1, 1, 1) - - # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3, 1, 1) - computed = transforms.Spectrogram()(waveform) - - # shape = (3, 2, 201, 1394) - self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) - self.assertTrue(torch.allclose(computed, expected)) - def test_batch_mulaw(self): waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 From 3eee1eb02f21f723f70d160ad2d05c5715064e12 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 14 Nov 2019 16:20:21 -0800 Subject: [PATCH 09/15] update compute_deltas to support batch. --- torchaudio/functional.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 2ab2a6f643..a0658ce631 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -841,15 +841,11 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): >>> delta2 = compute_deltas(delta) """ - dim = specgram.dim() - # pack batch shape = specgram.size() - specgram = specgram.reshape(-1, shape[-2], shape[-1]) + specgram = specgram.reshape(1, -1, shape[-1]) assert win_length >= 3 - assert specgram.dim() == 3 - assert not specgram.shape[1] % specgram.shape[0] n = (win_length - 1) // 2 @@ -861,15 +857,13 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): kernel = ( torch .arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype) - .repeat(specgram.shape[1], specgram.shape[0], 1) + .repeat(specgram.shape[1], 1, 1) ) - output = torch.nn.functional.conv1d( - specgram, kernel, groups=specgram.shape[1] // specgram.shape[0] - ) / denom + output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom # unpack batch - output = output.reshape(shape[:-1] + output.shape[-1:]) + output = output.reshape(shape) return output From 4b2121b8a840b715d3355360a16fa519d77f6004 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 14 Nov 2019 16:48:41 -0800 Subject: [PATCH 10/15] random spectrogram test. --- test/test_transforms.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index afeb22541e..582bd91c89 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -314,14 +314,13 @@ def test_compute_deltas_twochannel(self): self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) def test_batch_compute_deltas(self): - waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + specgram = torch.randn(2, 311, 278746) # Single then transform then batch - expected = transforms.ComputeDeltas()(waveform).unsqueeze(0).repeat(3, 1, 1, 1) + expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1) # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3, 1, 1) - computed = transforms.ComputeDeltas()(waveform) + computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1)) # shape = (3, 2, 201, 1394) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) From 607d37491ee862b05d00f38e8da3e65661f12cd4 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Thu, 14 Nov 2019 17:06:34 -0800 Subject: [PATCH 11/15] lowering memory. --- test/test_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index 582bd91c89..690cd2cc18 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -314,7 +314,7 @@ def test_compute_deltas_twochannel(self): self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) def test_batch_compute_deltas(self): - specgram = torch.randn(2, 311, 278746) + specgram = torch.randn(2, 31, 2786) # Single then transform then batch expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1) From 4f75044434f557454e5fe4d12afedd0250b39c83 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 18 Nov 2019 11:25:38 -0500 Subject: [PATCH 12/15] flake8. --- test/test_functional.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/test_functional.py b/test/test_functional.py index 10b576ad84..90d9c77321 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -55,10 +55,10 @@ def test_batch_pitch(self): # Single then transform then batch expected = F.detect_pitch_frequency(waveform, sample_rate) - expected = expected.unsqueeze(0).repeat(3,1,1) + expected = expected.unsqueeze(0).repeat(3, 1, 1) # Batch then transform - waveform = waveform.unsqueeze(0).repeat(3,1,1) + waveform = waveform.unsqueeze(0).repeat(3, 1, 1) computed = F.detect_pitch_frequency(waveform, sample_rate) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) From 37ae660cb53338b807652c87f10339033e255825 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 18 Nov 2019 14:21:06 -0500 Subject: [PATCH 13/15] update readme. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index eaed8acf36..7b57ecbaf0 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ Transforms expect and return the following dimensions. * `MuLawDecode`: (channel, time) -> (channel, time) * `Resample`: (channel, time) -> (channel, time) -Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. +Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use ellipsis "..." as a placeholder for the rest of the dimensions of a tensor. Contributing Guidelines ----------------------- From 47044ade78a0814650b55861434a01ff8ff8a0d1 Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 18 Nov 2019 17:34:42 -0500 Subject: [PATCH 14/15] adding example use. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7b57ecbaf0..7137779cc0 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ Transforms expect and return the following dimensions. * `MuLawDecode`: (channel, time) -> (channel, time) * `Resample`: (channel, time) -> (channel, time) -Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use ellipsis "..." as a placeholder for the rest of the dimensions of a tensor. +Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. an optional batching dimension. Contributing Guidelines ----------------------- From 44b4466746ee73b990f2ca7cce6a2e2cfa9c056b Mon Sep 17 00:00:00 2001 From: Vincent Quenneville-Belair Date: Mon, 18 Nov 2019 17:44:34 -0500 Subject: [PATCH 15/15] also mentioning channel. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7137779cc0..65b79e4784 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ Transforms expect and return the following dimensions. * `MuLawDecode`: (channel, time) -> (channel, time) * `Resample`: (channel, time) -> (channel, time) -Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. an optional batching dimension. +Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions. Contributing Guidelines -----------------------