Skip to content

Commit d3e146f

Browse files
authored
[BC-Breaking] Drop pseudo complex support from phase_vocoder / TimeStretch (#1957)
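Following the plan in #1337, this commit drops support for the pseudo-complex type (real tensors with a trailing dimension of size 2) from `F.phase_vocoder` and `T.TimeStretch`; both now accept only tensors of native complex dtype.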
1 parent 5ec6ada commit d3e146f

File tree

8 files changed

+34
-95
lines changed

test/torchaudio_unittest/functional/functional_impl.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -429,11 +429,8 @@ def test_resample_waveform_downsample_accuracy(self, resampling_method, i):
429429
def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
430430
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)
431431

432-
@nested_params(
433-
[0.5, 1.01, 1.3],
434-
[True, False],
435-
)
436-
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
432+
@nested_params([0.5, 1.01, 1.3])
433+
def test_phase_vocoder_shape(self, rate):
437434
"""Verify the output shape of phase vocoder"""
438435
hop_length = 256
439436
num_freq = 1025
@@ -443,8 +440,6 @@ def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
443440
torch.random.manual_seed(42)
444441
spec = torch.randn(
445442
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
446-
if test_pseudo_complex:
447-
spec = torch.view_as_real(spec)
448443

449444
phase_advance = torch.linspace(
450445
0,
@@ -456,7 +451,7 @@ def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
456451

457452
assert spec.dim() == spec_stretch.dim()
458453
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
459-
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
454+
output_shape = spec_stretch.shape
460455
assert output_shape == expected_shape
461456

462457
@parameterized.expand(

test/torchaudio_unittest/functional/librosa_compatibility_test_impl.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -126,11 +126,8 @@ def test_amplitude_to_DB(self):
126126

127127
@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
128128
class FunctionalComplex(TestBaseMixin):
129-
@nested_params(
130-
[0.5, 1.01, 1.3],
131-
[True, False],
132-
)
133-
def test_phase_vocoder(self, rate, test_pseudo_complex):
129+
@nested_params([0.5, 1.01, 1.3])
130+
def test_phase_vocoder(self, rate):
134131
hop_length = 256
135132
num_freq = 1025
136133
num_frames = 400
@@ -147,15 +144,11 @@ def test_phase_vocoder(self, rate, test_pseudo_complex):
147144
device=self.device,
148145
dtype=torch.float64)[..., None]
149146

150-
stretched = F.phase_vocoder(
151-
torch.view_as_real(spec) if test_pseudo_complex else spec,
152-
rate=rate, phase_advance=phase_advance)
147+
stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)
153148

154149
expected_stretched = librosa.phase_vocoder(
155150
spec.cpu().numpy(),
156151
rate=rate,
157152
hop_length=hop_length)
158153

159-
self.assertEqual(
160-
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
161-
torch.from_numpy(expected_stretched))
154+
self.assertEqual(stretched, torch.from_numpy(expected_stretched))

test/torchaudio_unittest/functional/torchscript_consistency_impl.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
import torch
55
import torchaudio.functional as F
6-
from parameterized import parameterized
76

87
from torchaudio_unittest import common_utils
98
from torchaudio_unittest.common_utils import (
@@ -31,14 +30,11 @@ def _assert_consistency(self, func, tensor, shape_only=False):
3130
output = output.shape
3231
self.assertEqual(ts_output, output)
3332

34-
def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
33+
def _assert_consistency_complex(self, func, tensor):
3534
assert tensor.is_complex()
3635
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
3736
ts_func = torch_script(func)
3837

39-
if test_pseudo_complex:
40-
tensor = torch.view_as_real(tensor)
41-
4238
torch.random.manual_seed(40)
4339
output = func(tensor)
4440

@@ -641,25 +637,22 @@ def func_beta(tensor):
641637
self._assert_consistency(func, tensor)
642638
self._assert_consistency(func_beta, tensor)
643639

644-
@parameterized.expand([(True, ), (False, )])
645-
def test_phase_vocoder(self, test_paseudo_complex):
640+
def test_phase_vocoder(self):
646641
def func(tensor):
647-
is_complex = tensor.is_complex()
648-
649-
n_freq = tensor.size(-2 if is_complex else -3)
642+
n_freq = tensor.size(-2)
650643
rate = 0.5
651644
hop_length = 256
652645
phase_advance = torch.linspace(
653646
0,
654647
3.14 * hop_length,
655648
n_freq,
656-
dtype=(torch.real(tensor) if is_complex else tensor).dtype,
649+
dtype=torch.real(tensor).dtype,
657650
device=tensor.device,
658651
)[..., None]
659652
return F.phase_vocoder(tensor, rate, phase_advance)
660653

661654
tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
662-
self._assert_consistency_complex(func, tensor, test_paseudo_complex)
655+
self._assert_consistency_complex(func, tensor)
663656

664657

665658
class FunctionalFloat32Only(TestBaseMixin):

test/torchaudio_unittest/transforms/autograd_test_impl.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,8 @@ def test_timestretch_zeros_fail(self):
226226
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
227227
self.assert_grad(transform, [spectrogram])
228228

229-
@nested_params(
230-
[0.7, 0.8, 0.9, 1.0, 1.3],
231-
[False, True],
232-
)
233-
def test_timestretch_non_zero(self, rate, test_pseudo_complex):
229+
@nested_params([0.7, 0.8, 0.9, 1.0, 1.3])
230+
def test_timestretch_non_zero(self, rate):
234231
"""Verify that ``T.TimeStretch`` does not fail if it's not close to 0
235232
236233
``T.TimeStretch`` is not differentiable around 0, so this test checks the differentiability
@@ -254,8 +251,6 @@ def test_timestretch_non_zero(self, rate, test_pseudo_complex):
254251
epsilon = 1e-2
255252
too_close = spectrogram.abs() < epsilon
256253
spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs()
257-
if test_pseudo_complex:
258-
spectrogram = torch.view_as_real(spectrogram)
259254
self.assert_grad(transform, [spectrogram])
260255

261256
def test_psd(self):

test/torchaudio_unittest/transforms/batch_consistency_test.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -124,20 +124,16 @@ def test_batch_lfcc(self):
124124

125125
self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)
126126

127-
@parameterized.expand([(True, ), (False, )])
128-
def test_batch_TimeStretch(self, test_pseudo_complex):
127+
def test_batch_TimeStretch(self):
129128
rate = 2
130129
num_freq = 1025
131-
num_frames = 400
132130
batch = 3
133131

134-
spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64)
135-
if test_pseudo_complex:
136-
spec = torch.view_as_real(spec)
137-
132+
tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch)
133+
spec = common_utils.get_spectrogram(tensor, n_fft=num_freq)
138134
transform = T.TimeStretch(
139135
fixed_rate=rate,
140-
n_freq=num_freq,
136+
n_freq=num_freq // 2 + 1,
141137
hop_length=512
142138
)
143139

test/torchaudio_unittest/transforms/torchscript_consistency_impl.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,13 @@ def _assert_consistency(self, transform, tensor, *args):
2424
ts_output = ts_transform(tensor, *args)
2525
self.assertEqual(ts_output, output)
2626

27-
def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args):
27+
def _assert_consistency_complex(self, transform, tensor, *args):
2828
assert tensor.is_complex()
2929
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
3030
transform = transform.to(device=self.device, dtype=self.dtype)
3131

3232
ts_transform = torch_script(transform)
3333

34-
if test_pseudo_complex:
35-
tensor = torch.view_as_real(tensor)
3634
output = transform(tensor, *args)
3735
ts_output = ts_transform(tensor, *args)
3836
self.assertEqual(ts_output, output)
@@ -120,16 +118,21 @@ def test_SpectralCentroid(self):
120118
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
121119
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)
122120

123-
@parameterized.expand([(True, ), (False, )])
124-
def test_TimeStretch(self, test_pseudo_complex):
125-
n_freq = 400
121+
def test_TimeStretch(self):
122+
n_fft = 1025
123+
n_freq = n_fft // 2 + 1
126124
hop_length = 512
127125
fixed_rate = 1.3
128-
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
126+
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cfloat)
127+
batch = 10
128+
num_channels = 2
129+
130+
waveform = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch * num_channels)
131+
tensor = common_utils.get_spectrogram(waveform, n_fft=n_fft)
132+
tensor = tensor.reshape(batch, num_channels, n_freq, -1)
129133
self._assert_consistency_complex(
130134
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
131135
tensor,
132-
test_pseudo_complex
133136
)
134137

135138
def test_PitchShift(self):
@@ -152,7 +155,7 @@ def test_PSD_with_mask(self):
152155
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
153156
spectrogram = spectrogram.to(self.device)
154157
mask = torch.rand(spectrogram.shape[-2:], device=self.device)
155-
self._assert_consistency_complex(T.PSD(), spectrogram, False, mask)
158+
self._assert_consistency_complex(T.PSD(), spectrogram, mask)
156159

157160

158161
class TransformsFloat32Only(TestBaseMixin):
@@ -188,5 +191,5 @@ def test_MVDR(self, solution, online):
188191
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
189192
self._assert_consistency_complex(
190193
T.MVDR(solution=solution, online=online),
191-
spectrogram, False, mask_s, mask_n
194+
spectrogram, mask_s, mask_n
192195
)

torchaudio/functional/functional.py

Lines changed: 2 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -714,8 +714,7 @@ def phase_vocoder(
714714
715715
Args:
716716
complex_specgrams (Tensor):
717-
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
718-
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
717+
A tensor of dimension `(..., freq, num_frame)` with complex dtype.
719718
rate (float): Speed-up factor
720719
phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`
721720
@@ -724,7 +723,7 @@ def phase_vocoder(
724723
Stretched spectrogram. The resulting tensor is of the same dtype as the input
725724
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.
726725
727-
Example - With Tensor of complex dtype
726+
Example
728727
>>> freq, hop_length = 1025, 512
729728
>>> # (channel, freq, time)
730729
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
@@ -734,41 +733,10 @@ def phase_vocoder(
734733
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
735734
>>> x.shape # with 231 == ceil(300 / 1.3)
736735
torch.Size([2, 1025, 231])
737-
738-
Example - With Tensor of real dtype and extra dimension for complex field
739-
>>> freq, hop_length = 1025, 512
740-
>>> # (channel, freq, time, complex=2)
741-
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
742-
>>> rate = 1.3 # Speed up by 30%
743-
>>> phase_advance = torch.linspace(
744-
>>> 0, math.pi * hop_length, freq)[..., None]
745-
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
746-
>>> x.shape # with 231 == ceil(300 / 1.3)
747-
torch.Size([2, 1025, 231, 2])
748736
"""
749737
if rate == 1.0:
750738
return complex_specgrams
751739

752-
if not complex_specgrams.is_complex():
753-
warnings.warn(
754-
"The support for pseudo complex type in `torchaudio.functional.phase_vocoder` and "
755-
"`torchaudio.transforms.TimeStretch` is now deprecated and will be removed "
756-
"from 0.11 release."
757-
"Please migrate to native complex type by converting the input tensor with "
758-
"`torch.view_as_complex`. "
759-
"Please refer to https://github.com/pytorch/audio/issues/1337 "
760-
"for more details about torchaudio's plan to migrate to native complex type."
761-
)
762-
if complex_specgrams.size(-1) != 2:
763-
raise ValueError(
764-
"complex_specgrams must be either native complex tensors or "
765-
"real valued tensors with shape (..., 2)")
766-
767-
is_complex = complex_specgrams.is_complex()
768-
769-
if not is_complex:
770-
complex_specgrams = torch.view_as_complex(complex_specgrams)
771-
772740
# pack batch
773741
shape = complex_specgrams.size()
774742
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
@@ -813,9 +781,6 @@ def phase_vocoder(
813781

814782
# unpack batch
815783
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])
816-
817-
if not is_complex:
818-
return torch.view_as_real(complex_specgrams_stretch)
819784
return complex_specgrams_stretch
820785

821786

torchaudio/transforms.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -972,8 +972,7 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
972972
r"""
973973
Args:
974974
complex_specgrams (Tensor):
975-
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
976-
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
975+
A tensor of dimension `(..., freq, num_frame)` with complex dtype.
977976
overriding_rate (float or None, optional): speed up to apply to this batch.
978977
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
979978

0 commit comments

Comments
 (0)