diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 3eb869c876..3860e97f41 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -51,20 +51,34 @@ def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False): self.assertEqual(ts_output, output) - def test_spectrogram(self): + def test_spectrogram_complex(self): def func(tensor): n_fft = 400 ws = 400 hop = 200 pad = 0 window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) - power = 2. + power = None normalize = False return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize) tensor = common_utils.get_whitenoise() self._assert_consistency(func, tensor) + def test_spectrogram_real(self): + def func(tensor): + n_fft = 400 + ws = 400 + hop = 200 + pad = 0 + window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) + power = 2. + normalize = False + return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False) + + tensor = common_utils.get_whitenoise() + self._assert_consistency(func, tensor) + @skipIfRocm def test_griffinlim(self): def func(tensor): diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 70045018d9..cae0874be6 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -49,7 +49,7 @@ def spectrogram( center: bool = True, pad_mode: str = "reflect", onesided: bool = True, - return_complex: bool = False, + return_complex: bool = True, ) -> Tensor: r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. The spectrogram can be either magnitude-only or complex. 
@@ -76,8 +76,10 @@ def spectrogram( Indicates whether the resulting complex-valued Tensor should be represented with native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype mimicking complex value with an extra dimension for real and imaginary parts. - This argument is only effective when ``power=None``. - See also ``torch.view_as_real``. + (See also ``torch.view_as_real``.) + This argument is only effective when ``power=None``. It is ignored when + ``power`` is a number, because in that case the returned tensor is a + power spectrogram, which is real-valued. Returns: Tensor: Dimension (..., freq, time), freq is diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d524c373de..fff04174ba 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -63,8 +63,10 @@ class Spectrogram(torch.nn.Module): Indicates whether the resulting complex-valued Tensor should be represented with native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype mimicking complex value with an extra dimension for real and imaginary parts. - This argument is only effective when ``power=None``. - See also ``torch.view_as_real``. + (See also ``torch.view_as_real``.) + This argument is only effective when ``power=None``. It is ignored when + ``power`` is a number, because in that case the returned tensor is a + power spectrogram, which is real-valued. """ __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] @@ -80,7 +82,7 @@ def __init__(self, center: bool = True, pad_mode: str = "reflect", onesided: bool = True, - return_complex: bool = False) -> None: + return_complex: bool = True) -> None: super(Spectrogram, self).__init__() self.n_fft = n_fft # number of FFT bins. the returned STFT result will have n_fft // 2 + 1