Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):

self.assertEqual(ts_output, output)

def test_spectrogram_complex(self):
def test_spectrogram(self):
def func(tensor):
n_fft = 400
ws = 400
Expand All @@ -61,21 +61,7 @@ def func(tensor):
tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)

def test_spectrogram_real(self):
def func(tensor):
n_fft = 400
ws = 400
hop = 200
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2.
normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False)

tensor = common_utils.get_whitenoise()
self._assert_consistency(func, tensor)

def test_inverse_spectrogram_complex(self):
def test_inverse_spectrogram(self):
def func(tensor):
length = 400
n_fft = 400
Expand All @@ -90,22 +76,6 @@ def func(tensor):
tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
self._assert_consistency_complex(func, tensor)

def test_inverse_spectrogram_real(self):
def func(tensor):
length = 400
n_fft = 400
hop = 200
ws = 400
pad = 0
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
normalize = False
return F.inverse_spectrogram(tensor, length, pad, window, n_fft, hop, ws, normalize)

waveform = common_utils.get_whitenoise(sample_rate=8000, duration=0.05)
tensor = common_utils.get_spectrogram(waveform, n_fft=400, hop_length=200)
tensor = torch.view_as_real(tensor)
self._assert_consistency(func, tensor)

@skipIfRocm
def test_griffinlim(self):
def func(tensor):
Expand Down
6 changes: 1 addition & 5 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,15 +77,11 @@ def test_spectrogram(self, kwargs):
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
self.assert_grad(transform, [waveform], nondet_tol=1e-10)

@parameterized.expand([(False, ), (True, )])
def test_inverse_spectrogram(self, return_complex):
def test_inverse_spectrogram(self):
# create a realistic input:
waveform = get_whitenoise(sample_rate=8000, duration=0.05, n_channels=2)
length = waveform.shape[-1]
spectrogram = get_spectrogram(waveform, n_fft=400)
if not return_complex:
spectrogram = torch.view_as_real(spectrogram)

# test
inv_transform = T.InverseSpectrogram(n_fft=400)
self.assert_grad(inv_transform, [spectrogram, length])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,6 @@ def test_InverseSpectrogram(self):
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
self._assert_consistency_complex(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)

def test_InverseSpectrogram_pseudocomplex(self):
tensor = common_utils.get_whitenoise(sample_rate=8000)
spectrogram = common_utils.get_spectrogram(tensor, n_fft=400, hop_length=100)
spectrogram = torch.view_as_real(spectrogram)
self._assert_consistency(T.InverseSpectrogram(n_fft=400, hop_length=100), spectrogram)

@skipIfRocm
def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6))
Expand Down
33 changes: 8 additions & 25 deletions torchaudio/functional/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def spectrogram(
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = True,
return_complex: Optional[bool] = None,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
Expand All @@ -77,25 +77,18 @@ def spectrogram(
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
return_complex (bool, optional):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return_complex (bool, optional):
return_complex (bool or None, optional):

Copy link
Contributor Author

@mthrok mthrok Nov 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here, None is is not a valid input and only used to detect if user pass something, so in docstring, we will keep it as bool.

Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts.
(See also ``torch.view_as_real``.)
This argument is only effective when ``power=None``. It is ignored for
cases where ``power`` is a number as in those cases, the returned tensor is
power spectrogram, which is a real-valued tensor.
Deprecated and not used.

Returns:
Tensor: Dimension `(..., freq, time)`, freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
if power is None and not return_complex:
if return_complex is not None:
warnings.warn(
"The use of pseudo complex type in spectrogram is now deprecated."
"Please migrate to native complex type by providing `return_complex=True`. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
"`return_complex` argument is now deprecated and is not effective."
"`torchaudio.functional.spectrogram(power=None)` always returns a tensor with "
"complex dtype. Please remove the argument in the function call."
)

if pad > 0:
Expand Down Expand Up @@ -129,8 +122,6 @@ def spectrogram(
if power == 1.0:
return spec_f.abs()
return spec_f.abs().pow(power)
if not return_complex:
return torch.view_as_real(spec_f)
return spec_f


Expand Down Expand Up @@ -172,16 +163,8 @@ def inverse_spectrogram(
Tensor: Dimension `(..., time)`. Least squares estimation of the original signal.
"""

if spectrogram.dtype == torch.float32 or spectrogram.dtype == torch.float64:
warnings.warn(
"The use of pseudo complex type in inverse_spectrogram is now deprecated. "
"Please migrate to native complex type by using a complex tensor as input. "
"If the input is generated via spectrogram() function or transform, please use "
"return_complex=True as an argument to that function. "
"Please refer to https://github.com/pytorch/audio/issues/1337 "
"for more details about torchaudio's plan to migrate to native complex type."
)
spectrogram = torch.view_as_complex(spectrogram)
if not spectrogram.is_complex():
raise ValueError("Expected `spectrogram` to be complex dtype.")

if normalized:
spectrogram = spectrogram * window.pow(2.).sum().sqrt()
Expand Down
18 changes: 8 additions & 10 deletions torchaudio/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,7 @@ class Spectrogram(torch.nn.Module):
onesided (bool, optional): controls whether to return half of results to
avoid redundancy (Default: ``True``)
return_complex (bool, optional):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
return_complex (bool, optional):
return_complex (bool or None, optional):

Indicates whether the resulting complex-valued Tensor should be represented with
native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype
mimicking complex value with an extra dimension for real and imaginary parts.
(See also ``torch.view_as_real``.)
This argument is only effective when ``power=None``. It is ignored for
cases where ``power`` is a number as in those cases, the returned tensor is
power spectrogram, which is a real-valued tensor.
Deprecated and not used.

Example
>>> waveform, sample_rate = torchaudio.load('test.wav', normalize=True)
Expand All @@ -93,7 +87,7 @@ def __init__(self,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True,
return_complex: bool = True) -> None:
return_complex: Optional[bool] = None) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
Expand All @@ -108,7 +102,12 @@ def __init__(self,
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
self.return_complex = return_complex
if return_complex is not None:
warnings.warn(
"`return_complex` argument is now deprecated and is not effective."
"`torchaudio.transforms.Spectrogram(power=None)` always returns a tensor with "
"complex dtype. Please remove the argument in the function call."
)

def forward(self, waveform: Tensor) -> Tensor:
r"""
Expand All @@ -132,7 +131,6 @@ def forward(self, waveform: Tensor) -> Tensor:
self.center,
self.pad_mode,
self.onesided,
self.return_complex,
)


Expand Down