From 57c7f7b16e12000025731d1b56fd15fa4706828c Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Wed, 2 Jun 2021 15:28:21 +0000 Subject: [PATCH 1/3] [BC-Breaking] Default to native complex type when returning raw spectrogram Part of https://github.com/pytorch/audio/issues/1337 . - This change makes spectrogram return a native complex dtype when (and only when) returning a raw (complex-valued) spectrogram. - Change `return_complex=False` to `return_complex=True` in spectrogram ops. - `return_complex` is only effective when `power` is `None`. It is ignored when `power` is not `None`, because in that case the returned Tensor is a power spectrogram, which is real-valued. --- .../functional/torchscript_consistency_impl.py | 18 ++++++++++++++++-- torchaudio/functional/functional.py | 2 +- torchaudio/transforms.py | 2 +- 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py index 3eb869c876..3860e97f41 100644 --- a/test/torchaudio_unittest/functional/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/functional/torchscript_consistency_impl.py @@ -51,20 +51,34 @@ def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False): self.assertEqual(ts_output, output) - def test_spectrogram(self): + def test_spectrogram_complex(self): def func(tensor): n_fft = 400 ws = 400 hop = 200 pad = 0 window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) - power = 2. + power = None normalize = False return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize) tensor = common_utils.get_whitenoise() self._assert_consistency(func, tensor) + def test_spectrogram_real(self): + def func(tensor): + n_fft = 400 + ws = 400 + hop = 200 + pad = 0 + window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype) + power = 2. 
+ normalize = False + return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False) + + tensor = common_utils.get_whitenoise() + self._assert_consistency(func, tensor) + @skipIfRocm def test_griffinlim(self): def func(tensor): diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 70045018d9..10e6069eb8 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -49,7 +49,7 @@ def spectrogram( center: bool = True, pad_mode: str = "reflect", onesided: bool = True, - return_complex: bool = False, + return_complex: bool = True, ) -> Tensor: r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. The spectrogram can be either magnitude-only or complex. diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index d524c373de..5a856d4f3a 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -80,7 +80,7 @@ def __init__(self, center: bool = True, pad_mode: str = "reflect", onesided: bool = True, - return_complex: bool = False) -> None: + return_complex: bool = True) -> None: super(Spectrogram, self).__init__() self.n_fft = n_fft # number of FFT bins. 
the returned STFT result will have n_fft // 2 + 1 From ef648485ef55ce1e5d140bb93cf32044a571d0c8 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 4 Jun 2021 19:21:10 +0000 Subject: [PATCH 2/3] tweak docstring --- torchaudio/functional/functional.py | 6 ++++-- torchaudio/transforms.py | 6 ++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 10e6069eb8..229cfbf12c 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -76,8 +76,10 @@ def spectrogram( Indicates whether the resulting complex-valued Tensor should be represented with native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype mimicking complex value with an extra dimension for real and imaginary parts. - This argument is only effective when ``power=None``. - See also ``torch.view_as_real``. + (See also ``torch.view_as_real``.) + This argument is only effective when ``power=None``. It is ignored for + cases where ``power`` is a number as in those cases, the returned tensor is + power spectrogram, which is real-valued tensors. Returns: Tensor: Dimension (..., freq, time), freq is diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index 5a856d4f3a..e52a104e1e 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -63,8 +63,10 @@ class Spectrogram(torch.nn.Module): Indicates whether the resulting complex-valued Tensor should be represented with native complex dtype, such as `torch.cfloat` and `torch.cdouble`, or real dtype mimicking complex value with an extra dimension for real and imaginary parts. - This argument is only effective when ``power=None``. - See also ``torch.view_as_real``. + (See also ``torch.view_as_real``.) + This argument is only effective when ``power=None``. 
It is ignored for + cases where ``power`` is a number as in those cases, the returned tensor is + power spectrogram, which is real-valued tensors. """ __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] From 3cf6d5ab0f141e360a5e1d306a688631b12fc913 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Fri, 4 Jun 2021 19:22:12 +0000 Subject: [PATCH 3/3] fixup! tweak docstring --- torchaudio/functional/functional.py | 2 +- torchaudio/transforms.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index 229cfbf12c..cae0874be6 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -79,7 +79,7 @@ def spectrogram( (See also ``torch.view_as_real``.) This argument is only effective when ``power=None``. It is ignored for cases where ``power`` is a number as in those cases, the returned tensor is - power spectrogram, which is real-valued tensors. + power spectrogram, which is a real-valued tensor. Returns: Tensor: Dimension (..., freq, time), freq is diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index e52a104e1e..fff04174ba 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -66,7 +66,7 @@ class Spectrogram(torch.nn.Module): (See also ``torch.view_as_real``.) This argument is only effective when ``power=None``. It is ignored for cases where ``power`` is a number as in those cases, the returned tensor is - power spectrogram, which is real-valued tensors. + power spectrogram, which is a real-valued tensor. """ __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']