11 changes: 3 additions & 8 deletions test/torchaudio_unittest/functional/functional_impl.py
@@ -429,11 +429,8 @@ def test_resample_waveform_downsample_accuracy(self, resampling_method, i):
def test_resample_waveform_upsample_accuracy(self, resampling_method, i):
self._test_resample_waveform_accuracy(up_scale_factor=1.0 + i / 20.0, resampling_method=resampling_method)

@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
@nested_params([0.5, 1.01, 1.3])
def test_phase_vocoder_shape(self, rate):
"""Verify the output shape of phase vocoder"""
hop_length = 256
num_freq = 1025
@@ -443,8 +440,6 @@ def test_phase_vocoder_shape(self, rate, test_pseudo_complex):
torch.random.manual_seed(42)
spec = torch.randn(
batch_size, num_freq, num_frames, dtype=self.complex_dtype, device=self.device)
if test_pseudo_complex:
spec = torch.view_as_real(spec)

phase_advance = torch.linspace(
0,
@@ -456,7 +451,7 @@ def test_phase_vocoder_shape(self, rate, test_pseudo_complex):

assert spec.dim() == spec_stretch.dim()
expected_shape = torch.Size([batch_size, num_freq, int(np.ceil(num_frames / rate))])
output_shape = (torch.view_as_complex(spec_stretch) if test_pseudo_complex else spec_stretch).shape
output_shape = spec_stretch.shape
assert output_shape == expected_shape

@parameterized.expand(
@@ -126,11 +126,8 @@ def test_amplitude_to_DB(self):

@unittest.skipIf(not LIBROSA_AVAILABLE, "Librosa not available")
class FunctionalComplex(TestBaseMixin):
@nested_params(
[0.5, 1.01, 1.3],
[True, False],
)
def test_phase_vocoder(self, rate, test_pseudo_complex):
@nested_params([0.5, 1.01, 1.3])
def test_phase_vocoder(self, rate):
hop_length = 256
num_freq = 1025
num_frames = 400
@@ -147,15 +144,11 @@ def test_phase_vocoder(self, rate, test_pseudo_complex):
device=self.device,
dtype=torch.float64)[..., None]

stretched = F.phase_vocoder(
torch.view_as_real(spec) if test_pseudo_complex else spec,
rate=rate, phase_advance=phase_advance)
stretched = F.phase_vocoder(spec, rate=rate, phase_advance=phase_advance)

expected_stretched = librosa.phase_vocoder(
spec.cpu().numpy(),
rate=rate,
hop_length=hop_length)

self.assertEqual(
torch.view_as_complex(stretched) if test_pseudo_complex else stretched,
torch.from_numpy(expected_stretched))
self.assertEqual(stretched, torch.from_numpy(expected_stretched))
@@ -3,7 +3,6 @@

import torch
import torchaudio.functional as F
from parameterized import parameterized

from torchaudio_unittest import common_utils
from torchaudio_unittest.common_utils import (
@@ -31,14 +30,11 @@ def _assert_consistency(self, func, tensor, shape_only=False):
output = output.shape
self.assertEqual(ts_output, output)

def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
def _assert_consistency_complex(self, func, tensor):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
ts_func = torch_script(func)

if test_pseudo_complex:
tensor = torch.view_as_real(tensor)

torch.random.manual_seed(40)
output = func(tensor)

@@ -641,25 +637,22 @@ def func_beta(tensor):
self._assert_consistency(func, tensor)
self._assert_consistency(func_beta, tensor)

@parameterized.expand([(True, ), (False, )])
def test_phase_vocoder(self, test_paseudo_complex):
def test_phase_vocoder(self):
def func(tensor):
is_complex = tensor.is_complex()

n_freq = tensor.size(-2 if is_complex else -3)
n_freq = tensor.size(-2)
rate = 0.5
hop_length = 256
phase_advance = torch.linspace(
0,
3.14 * hop_length,
n_freq,
dtype=(torch.real(tensor) if is_complex else tensor).dtype,
dtype=torch.real(tensor).dtype,
device=tensor.device,
)[..., None]
return F.phase_vocoder(tensor, rate, phase_advance)

tensor = torch.view_as_complex(torch.randn(2, 1025, 400, 2))
self._assert_consistency_complex(func, tensor, test_paseudo_complex)
self._assert_consistency_complex(func, tensor)


class FunctionalFloat32Only(TestBaseMixin):
9 changes: 2 additions & 7 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -226,11 +226,8 @@ def test_timestretch_zeros_fail(self):
spectrogram = get_spectrogram(waveform, n_fft=n_fft, power=None)
self.assert_grad(transform, [spectrogram])

@nested_params(
[0.7, 0.8, 0.9, 1.0, 1.3],
[False, True],
)
def test_timestretch_non_zero(self, rate, test_pseudo_complex):
@nested_params([0.7, 0.8, 0.9, 1.0, 1.3])
def test_timestretch_non_zero(self, rate):
"""Verify that ``T.TimeStretch`` does not fail if it's not close to 0

``T.TimeStretch`` is not differentiable around 0, so this test checks the differentiability
@@ -254,8 +251,6 @@ def test_timestretch_non_zero(self, rate, test_pseudo_complex):
epsilon = 1e-2
too_close = spectrogram.abs() < epsilon
spectrogram[too_close] = epsilon * spectrogram[too_close] / spectrogram[too_close].abs()
if test_pseudo_complex:
spectrogram = torch.view_as_real(spectrogram)
self.assert_grad(transform, [spectrogram])

def test_psd(self):
12 changes: 4 additions & 8 deletions test/torchaudio_unittest/transforms/batch_consistency_test.py
@@ -124,20 +124,16 @@ def test_batch_lfcc(self):

self.assert_batch_consistency(transform, waveform, atol=1e-4, rtol=1e-5)

@parameterized.expand([(True, ), (False, )])
def test_batch_TimeStretch(self, test_pseudo_complex):
def test_batch_TimeStretch(self):
rate = 2
num_freq = 1025
num_frames = 400
batch = 3

spec = torch.randn(batch, num_freq, num_frames, dtype=torch.complex64)
if test_pseudo_complex:
spec = torch.view_as_real(spec)

tensor = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch)
spec = common_utils.get_spectrogram(tensor, n_fft=num_freq)
transform = T.TimeStretch(
fixed_rate=rate,
n_freq=num_freq,
n_freq=num_freq // 2 + 1,
hop_length=512
)

@@ -24,15 +24,13 @@ def _assert_consistency(self, transform, tensor, *args):
ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output)

def _assert_consistency_complex(self, transform, tensor, test_pseudo_complex=False, *args):
def _assert_consistency_complex(self, transform, tensor, *args):
assert tensor.is_complex()
tensor = tensor.to(device=self.device, dtype=self.complex_dtype)
transform = transform.to(device=self.device, dtype=self.dtype)

ts_transform = torch_script(transform)

if test_pseudo_complex:
tensor = torch.view_as_real(tensor)
output = transform(tensor, *args)
ts_output = ts_transform(tensor, *args)
self.assertEqual(ts_output, output)
@@ -120,16 +118,21 @@ def test_SpectralCentroid(self):
waveform = common_utils.get_whitenoise(sample_rate=sample_rate)
self._assert_consistency(T.SpectralCentroid(sample_rate=sample_rate), waveform)

@parameterized.expand([(True, ), (False, )])
def test_TimeStretch(self, test_pseudo_complex):
n_freq = 400
def test_TimeStretch(self):
n_fft = 1025
n_freq = n_fft // 2 + 1
hop_length = 512
fixed_rate = 1.3
tensor = torch.view_as_complex(torch.rand((10, 2, n_freq, 10, 2)))
tensor = torch.rand((10, 2, n_freq, 10), dtype=torch.cfloat)
Review comment from @nateanl (Member), Nov 3, 2021:
Could you change it by using the get_spectrogram method?

Reply from the Contributor Author:
Done. Using get_spectrogram for the first time in a while, and now I feel it's not user-friendly...
batch = 10
num_channels = 2

waveform = common_utils.get_whitenoise(sample_rate=8000, n_channels=batch * num_channels)
tensor = common_utils.get_spectrogram(waveform, n_fft=n_fft)
tensor = tensor.reshape(batch, num_channels, n_freq, -1)
self._assert_consistency_complex(
T.TimeStretch(n_freq=n_freq, hop_length=hop_length, fixed_rate=fixed_rate),
tensor,
test_pseudo_complex
)

def test_PitchShift(self):
Expand All @@ -152,7 +155,7 @@ def test_PSD_with_mask(self):
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = spectrogram.to(self.device)
mask = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(T.PSD(), spectrogram, False, mask)
self._assert_consistency_complex(T.PSD(), spectrogram, mask)


class TransformsFloat32Only(TestBaseMixin):
@@ -188,5 +191,5 @@ def test_MVDR(self, solution, online):
mask_n = torch.rand(spectrogram.shape[-2:], device=self.device)
self._assert_consistency_complex(
T.MVDR(solution=solution, online=online),
spectrogram, False, mask_s, mask_n
spectrogram, mask_s, mask_n
)
39 changes: 2 additions & 37 deletions torchaudio/functional/functional.py
@@ -714,8 +714,7 @@ def phase_vocoder(

Args:
complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
A tensor of dimension `(..., freq, num_frame)` with complex dtype.
rate (float): Speed-up factor
phase_advance (Tensor): Expected phase advance in each bin. Dimension of `(freq, 1)`

@@ -724,7 +723,7 @@
Stretched spectrogram. The resulting tensor is of the same dtype as the input
spectrogram, but the number of frames is changed to ``ceil(num_frame / rate)``.

Example - With Tensor of complex dtype
Example
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time)
>>> complex_specgrams = torch.randn(2, freq, 300, dtype=torch.cfloat)
@@ -734,41 +733,10 @@
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231])

Example - With Tensor of real dtype and extra dimension for complex field
>>> freq, hop_length = 1025, 512
>>> # (channel, freq, time, complex=2)
>>> complex_specgrams = torch.randn(2, freq, 300, 2)
>>> rate = 1.3 # Speed up by 30%
>>> phase_advance = torch.linspace(
>>> 0, math.pi * hop_length, freq)[..., None]
>>> x = phase_vocoder(complex_specgrams, rate, phase_advance)
>>> x.shape # with 231 == ceil(300 / 1.3)
torch.Size([2, 1025, 231, 2])
"""
if rate == 1.0:
return complex_specgrams

if not complex_specgrams.is_complex():
warnings.warn(
"The support for pseudo complex type in `torchaudio.functional.phase_vocoder` and "
"`torchaudio.transforms.TimeStretch` is now deprecated and will be removed "
"from 0.11 release."
"Please migrate to native complex type by converting the input tensor with "
"`torch.view_as_complex`. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
)
if complex_specgrams.size(-1) != 2:
raise ValueError(
"complex_specgrams must be either native complex tensors or "
"real valued tensors with shape (..., 2)")

is_complex = complex_specgrams.is_complex()

if not is_complex:
complex_specgrams = torch.view_as_complex(complex_specgrams)

# pack batch
shape = complex_specgrams.size()
Review comment from a Member:
The complex_specgrams argument could be renamed to specgram, since its dtype is now always complex, but that would be BC-breaking; maybe it's fine as-is.

Reply from the Contributor Author:
Yeah, unfortunately it is BC-breaking, and we do not have a strong reason to push a BC-breaking change here.
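
A hypothetical sketch (not part of this PR) of why the rename would break callers: existing code may pass the spectrogram by keyword under the current parameter name.

import math
import torch
import torchaudio.functional as F

freq, hop_length = 1025, 512
spec = torch.randn(2, freq, 300, dtype=torch.cfloat)
phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]

F.phase_vocoder(complex_specgrams=spec, rate=1.3, phase_advance=phase_advance)
# ^ works with the current name; after a rename to `specgram`, this keyword call
#   would raise a TypeError, which is the BC break discussed above.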

complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-2:]))
@@ -813,9 +781,6 @@ def phase_vocoder(

# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-2] + complex_specgrams_stretch.shape[1:])

if not is_complex:
return torch.view_as_real(complex_specgrams_stretch)
return complex_specgrams_stretch
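
For reference, a minimal migration sketch (not part of this diff) for callers that still hold the deprecated pseudo-complex `(..., freq, num_frame, 2)` layout; it only assumes `torch.view_as_complex`, `torch.view_as_real`, and the `phase_vocoder` signature shown above.

import math
import torch
import torchaudio.functional as F

freq, hop_length = 1025, 512
pseudo = torch.randn(2, freq, 300, 2)        # legacy real layout (..., freq, num_frame, 2)
spec = torch.view_as_complex(pseudo)         # native complex tensor, shape (2, 1025, 300)
phase_advance = torch.linspace(0, math.pi * hop_length, freq)[..., None]
stretched = F.phase_vocoder(spec, rate=1.3, phase_advance=phase_advance)
print(stretched.shape)                       # torch.Size([2, 1025, 231]), 231 == ceil(300 / 1.3)
# Only if downstream code still expects the legacy layout:
legacy_out = torch.view_as_real(stretched)   # shape (2, 1025, 231, 2)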


3 changes: 1 addition & 2 deletions torchaudio/transforms.py
@@ -972,8 +972,7 @@ def forward(self, complex_specgrams: Tensor, overriding_rate: Optional[float] =
r"""
Args:
complex_specgrams (Tensor):
Either a real tensor of dimension of `(..., freq, num_frame, complex=2)`
or a tensor of dimension `(..., freq, num_frame)` with complex dtype.
A tensor of dimension `(..., freq, num_frame)` with complex dtype.
overriding_rate (float or None, optional): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``. (Default: ``None``)
