diff --git a/README.md b/README.md index eaed8acf36..65b79e4784 100644 --- a/README.md +++ b/README.md @@ -129,7 +129,7 @@ Transforms expect and return the following dimensions. * `MuLawDecode`: (channel, time) -> (channel, time) * `Resample`: (channel, time) -> (channel, time) -Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. +Complex numbers are supported via tensors of dimension (..., 2), and torchaudio provides `complex_norm` and `angle` to convert such a tensor into its magnitude and phase. Here, and in the documentation, we use an ellipsis "..." as a placeholder for the rest of the dimensions of a tensor, e.g. optional batching and channel dimensions. Contributing Guidelines ----------------------- diff --git a/test/test_functional.py b/test/test_functional.py index be0804b6c6..90d9c77321 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -21,6 +21,10 @@ class TestFunctional(unittest.TestCase): number_of_trials = 100 specgram = torch.tensor([1., 2., 3., 4.]) + test_dirpath, test_dir = common_utils.create_temp_assets_dir() + test_filepath = os.path.join(test_dirpath, 'assets', + 'steam-train-whistle-daniel_simon.mp3') + def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) @@ -46,6 +50,20 @@ def test_compute_deltas_randn(self): computed = F.compute_deltas(specgram, win_length=win_length) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def test_batch_pitch(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) + + # Single then transform then batch + expected = F.detect_pitch_frequency(waveform, sample_rate) + expected = expected.unsqueeze(0).repeat(3, 1, 1) + + # Batch then transform + waveform = waveform.unsqueeze(0).repeat(3, 1, 1) + computed = F.detect_pitch_frequency(waveform, sample_rate) + + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8): # trim sound for case when constructed signal is shorter than original sound = sound[..., :estimate.size(-1)] @@ -58,6 +76,8 @@ def _test_istft_is_inverse_of_stft(self, kwargs): # operation to check whether we can reconstruct signal for data_size in self.data_sizes: for i in range(self.number_of_trials): + + # Non-batch sound = common_utils.random_float_tensor(i, data_size) stft = torch.stft(sound, **kwargs) @@ -65,6 +85,14 @@ def _test_istft_is_inverse_of_stft(self, kwargs): self._compare_estimate(sound, estimate) + # Batch + stft = torch.stft(sound, **kwargs) + stft = stft.repeat(3, 1, 1, 1, 1) + sound = sound.repeat(3, 1, 1) + + estimate = torchaudio.functional.istft(stft, length=sound.size(1), **kwargs) + self._compare_estimate(sound, estimate) + def test_istft_is_inverse_of_stft1(self): # hann_window, centered, normalized, onesided kwargs1 = { @@ -326,15 +354,30 @@ def test_pitch(self): for filename, freq_ref in tests: waveform, sample_rate = torchaudio.load(filename) - # Convert to stereo for testing purposes - waveform = waveform.repeat(2, 1, 1) - freq = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) threshold = 1 s = ((freq - freq_ref).abs() > threshold).sum() self.assertFalse(s) + # Convert to stereo and batch for testing purposes + freq = freq.repeat(3, 2, 1, 1) + waveform = waveform.repeat(3, 2, 1, 1) + + freq2 = torchaudio.functional.detect_pitch_frequency(waveform, sample_rate) + + assert torch.allclose(freq, freq2, atol=1e-5) + + def _test_batch(self, functional): + waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + + # Single then transform then batch + expected = functional(waveform).unsqueeze(0).repeat(3, 1, 1, 1) + + # Batch then transform + waveform = waveform.unsqueeze(0).repeat(3, 1, 1) + computed = functional(waveform) + def _num_stft_bins(signal_len, fft_len, hop_length, pad): return (signal_len + 2 * pad - fft_len + hop_length) // hop_length diff --git a/test/test_transforms.py b/test/test_transforms.py index a438c4481c..82594bad01 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -313,6 +313,45 @@ def test_compute_deltas_twochannel(self): computed = transform(specgram) self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape)) + def test_batch_compute_deltas(self): + specgram = torch.randn(2, 31, 2786) + + # Single then transform then batch + expected = transforms.ComputeDeltas()(specgram).repeat(3, 1, 1, 1) + + # Batch then transform + computed = transforms.ComputeDeltas()(specgram.repeat(3, 1, 1, 1)) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + + def test_batch_mulaw(self): + waveform, sample_rate = torchaudio.load(self.test_filepath) # (2, 278756), 44100 + + # Single then transform then batch + waveform_encoded = transforms.MuLawEncoding()(waveform) + expected = waveform_encoded.unsqueeze(0).repeat(3, 1, 1) + + # Batch then transform + waveform_batched = waveform.unsqueeze(0).repeat(3, 1, 1) + computed = transforms.MuLawEncoding()(waveform_batched) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + + # Single then transform then batch + waveform_decoded = transforms.MuLawDecoding()(waveform_encoded) + expected = waveform_decoded.unsqueeze(0).repeat(3, 1, 1) + + # Batch then transform + computed = transforms.MuLawDecoding()(computed) + + # shape = (3, 2, 201, 1394) + self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape)) + self.assertTrue(torch.allclose(computed, expected)) + def test_batch_spectrogram(self): waveform, sample_rate = torchaudio.load(self.test_filepath) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 4397ddd51b..44b52cf277 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -114,16 +114,20 @@ def istft( original signal length). (Default: whole signal) Returns: - torch.Tensor: Least squares estimation of the original signal of size - (channel, signal_length) or (signal_length) + torch.Tensor: Least squares estimation of the original signal of size (..., signal_length) """ stft_matrix_dim = stft_matrix.dim() - assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim) + assert 3 <= stft_matrix_dim, "Incorrect stft dimension: %d" % (stft_matrix_dim) + assert stft_matrix.nelement() > 0 if stft_matrix_dim == 3: # add a channel dimension stft_matrix = stft_matrix.unsqueeze(0) + # pack batch + shape = stft_matrix.size() + stft_matrix = stft_matrix.reshape(-1, *shape[-3:]) + dtype = stft_matrix.dtype device = stft_matrix.device fft_size = stft_matrix.size(1) @@ -208,8 +212,12 @@ def istft( y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) + # unpack batch + y = y.reshape(shape[:-3] + y.shape[-1:]) + if stft_matrix_dim == 3: # remove the channel dimension y = y.squeeze(0) + return y @@ -514,14 +522,14 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): dtype=complex_specgrams.dtype) alphas = time_steps % 1.0 - phase_0 = angle(complex_specgrams[:, :, :1]) + phase_0 = angle(complex_specgrams[..., :1, :]) # Time Padding complex_specgrams = torch.nn.functional.pad(complex_specgrams, [0, 0, 0, 2]) # (new_bins, freq, 2) - complex_specgrams_0 = complex_specgrams[:, :, time_steps.long()] - complex_specgrams_1 = complex_specgrams[:, :, (time_steps + 1).long()] + complex_specgrams_0 = complex_specgrams.index_select(-2, time_steps.long()) + complex_specgrams_1 = complex_specgrams.index_select(-2, (time_steps + 1).long()) angle_0 = angle(complex_specgrams_0) angle_1 = angle(complex_specgrams_1) @@ -534,7 +542,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance): # Compute Phase Accum phase = phase + phase_advance - phase = torch.cat([phase_0, phase[:, :, :-1]], dim=-1) + phase = torch.cat([phase_0, phase[..., :-1]], dim=-1) phase_acc = torch.cumsum(phase, -1) mag = alphas * norm_1 + (1 - alphas) * norm_0 @@ -554,7 +562,7 @@ def lfilter(waveform, a_coeffs, b_coeffs): Performs an IIR filter by evaluating difference equation. Args: - waveform (torch.Tensor): audio waveform of dimension of `(channel, time)`. Must be normalized to -1 to 1. + waveform (torch.Tensor): audio waveform of dimension of `(..., time)`. Must be normalized to -1 to 1. a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`. Lower delays coefficients are first, e.g. `[a0, a1, a2, ...]`. Must be same size as b_coeffs (pad with 0's as necessary). @@ -563,10 +571,16 @@ def lfilter(waveform, a_coeffs, b_coeffs): Must be same size as a_coeffs (pad with 0's as necessary). Returns: - output_waveform (torch.Tensor): Dimension of `(channel, time)`. Output will be clipped to -1 to 1. + output_waveform (torch.Tensor): Dimension of `(..., time)`. Output will be clipped to -1 to 1. """ + dim = waveform.dim() + + # pack batch + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + assert(a_coeffs.size(0) == b_coeffs.size(0)) assert(len(waveform.size()) == 2) assert(waveform.device == a_coeffs.device) @@ -606,7 +620,14 @@ def lfilter(waveform, a_coeffs, b_coeffs): padded_output_waveform[:, i_sample + n_order - 1] = o0 - return torch.min(ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):])) + output = torch.min( + ones, torch.max(ones * -1, padded_output_waveform[:, (n_order - 1):]) + ) + + # unpack batch + output = output.reshape(shape[:-1] + output.shape[-1:]) + + return output @torch.jit.script @@ -817,12 +838,12 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): :math:`N` is (`win_length`-1)//2. Args: - specgram (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + specgram (torch.Tensor): Tensor of audio of dimension (..., freq, time) win_length (int): The window length used for computing delta mode (str): Mode parameter passed to padding Returns: - deltas (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + deltas (torch.Tensor): Tensor of audio of dimension (..., freq, time) Example >>> specgram = torch.randn(1, 40, 1000) @@ -830,9 +851,11 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): >>> delta2 = compute_deltas(delta) """ + # pack batch + shape = specgram.size() + specgram = specgram.reshape(1, -1, shape[-1]) + assert win_length >= 3 - assert specgram.dim() == 3 - assert not specgram.shape[1] % specgram.shape[0] n = (win_length - 1) // 2 @@ -844,12 +867,15 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): kernel = ( torch .arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype) - .repeat(specgram.shape[1], specgram.shape[0], 1) + .repeat(specgram.shape[1], 1, 1) ) - return torch.nn.functional.conv1d( - specgram, kernel, groups=specgram.shape[1] // specgram.shape[0] - ) / denom + output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom + + # unpack batch + output = output.reshape(shape) + + return output @torch.jit.script @@ -982,16 +1008,22 @@ def detect_pitch_frequency( It is implemented using normalized cross-correlation function and median smoothing. Args: - waveform (torch.Tensor): Tensor of audio of dimension (channel, freq, time) + waveform (torch.Tensor): Tensor of audio of dimension (..., freq, time) sample_rate (int): The sample rate of the waveform (Hz) win_length (int): The window length for median smoothing (in number of frames) freq_low (int): Lowest frequency that can be detected (Hz) freq_high (int): Highest frequency that can be detected (Hz) Returns: - freq (torch.Tensor): Tensor of audio of dimension (channel, frame) + freq (torch.Tensor): Tensor of audio of dimension (..., frame) """ + dim = waveform.dim() + + # pack batch + shape = waveform.size() + waveform = waveform.reshape([-1] + shape[-1:]) + nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low) indices = _find_max_per_frame(nccf, sample_rate, freq_high) indices = _median_smoothing(indices, win_length) @@ -1000,4 +1032,7 @@ def detect_pitch_frequency( EPSILON = 10 ** (-9) freq = sample_rate / (EPSILON + indices.to(torch.float)) + # unpack batch + freq = freq.reshape(shape[:-1] + freq.shape[-1:]) + return freq