From ed46dca991b8fd2f2c5bd08d564fa09483ec846c Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Mon, 1 Nov 2021 22:30:39 -0400 Subject: [PATCH 1/2] [BC-Breaking] Remove pseudo complex type support from phase_vocoder / TimeStretch --- .../functional/functional_impl.py | 11 ++---- .../librosa_compatibility_test_impl.py | 15 ++----- .../torchscript_consistency_impl.py | 6 +-- .../transforms/autograd_test_impl.py | 9 +---- .../transforms/batch_consistency_test.py | 6 +-- .../torchscript_consistency_impl.py | 7 ++-- torchaudio/functional/functional.py | 39 +------------------ torchaudio/transforms.py | 3 +- 8 files changed, 18 insertions(+), 78 deletions(-) diff --git a/test/torchaudio_unittest/functional/functional_impl.py b/test/torchaudio_unittest/functional/functional_impl.py index b355293235..541b660679 100644 --- a/test/torchaudio_unittest/functional/functional_impl.py +++ b/test/torchaudio_unittest/functional/functional_impl.py @@ -429,11 +429,8 @@ def test_resample_waveform_downsample_accuracy(self, resampling_method, i): def test_resample_waveform_upsample_accuracy(self, resampling_method, i): self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method) - @nested_params( - [0.5, 1.01, 1.3], - [True, False], - ) - def test_phase_vocoder_shape(self, rate, test_pseudo_complex): + @nested_params([0.5, 1.01, 1.3]) + def test_phase_vocoder_shape(self, rate): """Verify the output shape of phase vocoder""" hop_length = 256 num_freq = 1025 @@ -443,8 +440,6 @@ def test_phase_vocoder_shape(self, rate, test_pseudo_complex): torch.random.manual_seed(42) spec = torch.randn( batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device) - if test_pseudo_complex: - spec = torch.view_as_real(spec) phase_advance = torch.linspace( 0, @@ -456,7 +451,7 @@ def test_phase_vocoder_shape(self, rate, test_pseudo_complex): assert spec.dim() == spec_stretch.dim() expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))]) - output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape + output_shape = spec_stretch.shape assert output_shape == expected_shape @parameterized.expand( diff --git a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py index a63f0da9d4..15d696a0d2 100644 --- a/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py +++ b/test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py @@ -126,11 +126,8 @@ def test_amplitude_to_DB(self): @unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available") class FunctionalComplex(TestBaseMixin): - @nested_params( - [0.5, 1.01, 1.3], - [True, False], - ) - def test_phase_vocoder(self, rate, test_pseudo_complex): + @nested_params([0.5, 1.01, 1.3]) + def test_phase_vocoder(self, rate): hop_length = 256 num_freq = 1025 num_frames = 400 @@ -147,15 +144,11 @@ def test_phase_vocoder(self, rate, test_pseudo_complex): device=self.device, dtype=torch.float64)[..., None] - stretched = F.phase_vocoder( - torch.view_as_real(spec) if test_pseudo_complex else spec, - rate=rate, phase_advance=phase_advance) + stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance) expected_stretched = librosa.phase_vocoder( spec.cpu().numpy(), rate=rate, hop_length=hop_length) - self.assertEqual( - torch.view_as_complex(stretched) if test_pseudo_complex else stretched, - torch.from_numpy(expected_stretched)) + self.assertEqual(stretched, torch.from_numpy(expected_stretched)) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 7925372584..5a23f76912 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -3,7 +3,6 @@ import torch import torchaudio.functional as F -from parameterized import parameterized from torchaudio_unittest import common_utils from torchaudio_unittest.common_utils import ( @@ -641,8 +640,7 @@ def func_beta(tensor): self._assert_consistency(func, tensor) self._assert_consistency(func_beta, tensor) - @parameterized.expand([(True, ), (False, )]) - def test_phase_vocoder(self, test_paseudo_complex): + def test_phase_vocoder(self): def func(tensor): is_complex = tensor.is_complex() @@ -659,7 +657,7 @@ def func(tensor): return F.phase_vocoder(tensor, rate, phase_advance) tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2)) - self._assert_consistency_complex(func, tensor, test_paseudo_complex) + self._assert_consistency_complex(func, tensor) class FunctionalFloat32Only(TestBaseMixin): diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 836f27790b..669165c18f 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -226,11 +226,8 @@ def test_timestretch_zeros_fail(self): spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None) self.assert_grad(transform, [spectrogram]) - @nested_params( - [0.7, 0.8, 0.9, 1.0, 1.3], - [False, True], - ) - def test_timestretch_non_zero(self, rate, test_pseudo_complex): + @nested_params([0.7, 0.8, 0.9, 1.0, 1.3]) + def test_timestretch_non_zero(self, rate): """Verify that ``T.TimeStretch`` does not fail if it's not close to 0 ``T.TimeStrech`` is not differentiable around 0, so this test checks the differentiability @@ -254,8 +251,6 @@ def test_timestretch_non_zero(self, rate, test_pseudo_complex): epsilon = 1e-2 too_close = spectrogram.abs() < epsilon spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs() - if test_pseudo_complex: - spectrogram = torch.view_as_real(spectrogram) self.assert_grad(transform, [spectrogram]) def test_psd(self): diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py index aa1f38174c..5ba9591823 100644 --- a/test/torchaudio_unittest/transforms/batch_consistency_test.py +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -124,17 +124,13 @@ def test_batch_lfcc(self): self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5) - @parameterized.expand([(True, ), (False, )]) - def test_batch_TimeStretch(self, test_pseudo_complex): + def test_batch_TimeStretch(self): rate = 2 num_freq = 1025 num_frames = 400 batch = 3 spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64) - if test_pseudo_complex: - spec = torch.view_as_real(spec) - transform = T.TimeStretch( fixed_rate=rate, n_freq=num_freq, diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index f10f2bf69f..3ae958dc76 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -120,16 +120,15 @@ def test_SpectralCentroid(self): waveform = common_utils.get_whitenoise(sample_rate=sample_rate) self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform) - @parameterized.expand([(True, ), (False, )]) - def test_TimeStretch(self, test_pseudo_complex): + def test_TimeStretch(self): n_freq = 400 hop_length = 512 fixed_rate = 1.3 - tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2))) + tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cfloat) self._assert_consistency_complex( T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), tensor, - test_pseudo_complex + False, ) def test_PitchShift(self): diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index aacfc80d7e..84a70af110 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -714,8 +714,7 @@ def phase_vocoder( Args: complex_specgrams (Tensor): - Either a real tensor of dimension of `(..., freq, num_frame, complex=2)` - or a tensor of dimension `(..., freq, num_frame)` with complex dtype. + A tensor of dimension `(..., freq, num_frame)` with complex dtype. rate (float): Speed-up factor phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)` @@ -724,7 +723,7 @@ def phase_vocoder( Stretched spectrogram. The resulting tensor is of the same dtype as the input spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``. - Example - With Tensor of complex dtype + Example >>> freq, hop_length = 1025, 512 >>> # (channel, freq, time) >>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat) @@ -734,41 +733,10 @@ def phase_vocoder( >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) >>> x.shape # with 231 == ceil(300 / 1.3) torch.Size([2, 1025, 231]) - - Example - With Tensor of real dtype and extra dimension for complex field - >>> freq, hop_length = 1025, 512 - >>> # (channel, freq, time, complex=2) - >>> complex_specgrams = torch.randn(2, freq, 300, 2) - >>> rate = 1.3 # Speed up by 30% - >>> phase_advance = torch.linspace( - >>> 0, math.pi * hop_length, freq)[..., None] - >>> x = phase_vocoder(complex_specgrams, rate, phase_advance) - >>> x.shape # with 231 == ceil(300 / 1.3) - torch.Size([2, 1025, 231, 2]) """ if rate == 1.0: return complex_specgrams - if not complex_specgrams.is_complex(): - warnings.warn( - "The support for pseudo complex type in `torchaudio.functional.phase_vocoder` and " - "`torchaudio.transforms.TimeStretch` is now deprecated and will be removed " - "from 0.11 release." - "Please migrate to native complex type by converting the input tensor with " - "`torch.view_as_complex`. " - "Please refer to https://github.com/pytorch/audio/issues/1337 " - "for more details about torchaudio's plan to migrate to native complex type." - ) - if complex_specgrams.size(-1) != 2: - raise ValueError( - "complex_specgrams must be either native complex tensors or " - "real valued tensors with shape (..., 2)") - - is_complex = complex_specgrams.is_complex() - - if not is_complex: - complex_specgrams = torch.view_as_complex(complex_specgrams) - # pack batch shape = complex_specgrams.size() complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:])) @@ -813,9 +781,6 @@ def phase_vocoder( # unpack batch complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:]) - - if not is_complex: - return torch.view_as_real(complex_specgrams_stretch) return complex_specgrams_stretch diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 04fc9d8e1e..bb64a43b14 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -972,8 +972,7 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] = r""" Args: complex_specgrams (Tensor): - Either a real tensor of dimension of `(..., freq, num_frame, complex=2)` - or a tensor of dimension `(..., freq, num_frame)` with complex dtype. + A tensor of dimension `(..., freq, num_frame)` with complex dtype. overriding_rate (float or None, optional): speed up to apply to this batch. If no rate is passed, use ``self.fixed_rate``. (Default: ``None``) From 8cf40ef606bf4c32be439994ee6b9aa1a8bc47d6 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 3 Nov 2021 15:41:59 -0400 Subject: [PATCH 2/2] fix --- .../functional/torchscript_consistency_impl.py | 11 +++-------- .../transforms/batch_consistency_test.py | 6 +++--- .../transforms/torchscript_consistency_impl.py | 18 +++++++++++------- 3 files changed, 17 insertions(+), 18 deletions(-) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 5a23f76912..9746f8ac69 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -30,14 +30,11 @@ def _assert_consistency(self, func, tensor, shape_only=False): output = output.shape self.assertEqual(ts_output, output) - def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False): + def _assert_consistency_complex(self, func, tensor): assert tensor.is_complex() tensor = tensor.to(device=self.device, dtype=self.complex_dtype) ts_func = torch_script(func) - if test_pseudo_complex: - tensor = torch.view_as_real(tensor) - torch.random.manual_seed(40) output = func(tensor) @@ -642,16 +639,14 @@ def func_beta(tensor): def test_phase_vocoder(self): def func(tensor): - is_complex = tensor.is_complex() - - n_freq = tensor.size(-2 if is_complex else -3) + n_freq = tensor.size(-2) rate = 0.5 hop_length = 256 phase_advance = torch.linspace( 0, 3.14 * hop_length, n_freq, - dtype=(torch.real(tensor) if is_complex else tensor).dtype, + dtype=torch.real(tensor).dtype, device=tensor.device, )[..., None] return F.phase_vocoder(tensor, rate, phase_advance) diff --git a/test/torchaudio_unittest/transforms/batch_consistency_test.py b/test/torchaudio_unittest/transforms/batch_consistency_test.py index 5ba9591823..f1ba7a98eb 100644 --- a/test/torchaudio_unittest/transforms/batch_consistency_test.py +++ b/test/torchaudio_unittest/transforms/batch_consistency_test.py @@ -127,13 +127,13 @@ def test_batch_lfcc(self): def test_batch_TimeStretch(self): rate = 2 num_freq = 1025 - num_frames = 400 batch = 3 - spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64) + tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch) + spec = common_utils.get_spectrogram(tensor, n_fft=num_freq) transform = T.TimeStretch( fixed_rate=rate, - n_freq=num_freq, + n_freq=num_freq // 2 + 1, hop_length=512 ) diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index 3ae958dc76..3d73eafff7 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -24,15 +24,13 @@ def _assert_consistency(self, transform, tensor, *args): ts_output = ts_transform(tensor, *args) self.assertEqual(ts_output, output) - def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args): + def _assert_consistency_complex(self, transform, tensor, *args): assert tensor.is_complex() tensor = tensor.to(device=self.device, dtype=self.complex_dtype) transform = transform.to(device=self.device, dtype=self.dtype) ts_transform = torch_script(transform) - if test_pseudo_complex: - tensor = torch.view_as_real(tensor) output = transform(tensor, *args) ts_output = ts_transform(tensor, *args) self.assertEqual(ts_output, output) @@ -121,14 +119,20 @@ def test_SpectralCentroid(self): self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform) def test_TimeStretch(self): - n_freq = 400 + n_fft = 1025 + n_freq = n_fft // 2 + 1 hop_length = 512 fixed_rate = 1.3 tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cfloat) + batch = 10 + num_channels = 2 + + waveform = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch * num_channels) + tensor = common_utils.get_spectrogram(waveform, n_fft=n_fft) + tensor = tensor.reshape(batch, num_channels, n_freq, -1) self._assert_consistency_complex( T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate), tensor, - False, ) def test_PitchShift(self): @@ -151,7 +155,7 @@ def test_PSD_with_mask(self): spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100) spectrogram = spectrogram.to(self.device) mask = torch.rand(spectrogram.shape[-2:], device=self.device) - self._assert_consistency_complex(T.PSD(), spectrogram, False, mask) + self._assert_consistency_complex(T.PSD(), spectrogram, mask) class TransformsFloat32Only(TestBaseMixin): @@ -187,5 +191,5 @@ def test_MVDR(self, solution, online): mask_n = torch.rand(spectrogram.shape[-2:], device=self.device) self._assert_consistency_complex( T.MVDR(solution=solution, online=online), - spectrogram, False, mask_s, mask_n + spectrogram, mask_s, mask_n )