From 44405f3cd8a4ae5c6227f355c39f6daa2d3b0173 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 5 Nov 2020 19:26:25 +0000 Subject: [PATCH 1/5] Update spectrogram to use complex --- torchaudio/functional/functional.py | 17 ++++++++++++-- torchaudio/transforms.py | 35 ++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index a26d7ab7fe..be18ceb3ef 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -48,7 +48,8 @@ def spectrogram( normalized: bool, center: bool = True, pad_mode: str = "reflect", - onesided: bool = True + onesided: bool = True, + return_complex: bool = False, ) -> Tensor: r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. The spectrogram can be either magnitude-only or complex. @@ -71,12 +72,22 @@ def spectrogram( :attr:`center` is ``True``. Default: ``"reflect"`` onesided (bool, optional): controls whether to return half of results to avoid redundancy. Default: ``True`` + return_complex (bool, optional): + ``return_complex = True``, this function returns the resulting Tensor in + complex dtype, otherwise it returns the resulting Tensor in real dtype with extra + dimension for real and imaginary parts. (see ``torch.view_as_real``). + When ``power`` is provided, the value must be False, as the resulting + Tensor represents real-valued power. Returns: Tensor: Dimension (..., freq, time), freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). """ + if power is not None and return_complex: + raise ValueError( + 'When `power` is provided, the return value is real-valued. 
' + 'Therefore, `return_complex` must be False.') if pad > 0: # TODO add "with torch.no_grad():" back when JIT supports it @@ -109,7 +120,9 @@ def spectrogram( if power == 1.0: return spec_f.abs() return spec_f.abs().pow(power) - return torch.view_as_real(spec_f) + if not return_complex: + return torch.view_as_real(spec_f) + return spec_f def griffinlim( diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index e5dd5b2210..bfd09ea0ce 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -52,6 +52,12 @@ class Spectrogram(torch.nn.Module): :attr:`center` is ``True``. Default: ``"reflect"`` onesided (bool, optional): controls whether to return half of results to avoid redundancy Default: ``True`` + return_complex (bool, optional): + ``return_complex = True``, this function returns the resulting Tensor in + complex dtype, otherwise it returns the resulting Tensor in real dtype with extra + dimension for real and imaginary parts. (see ``torch.view_as_real``). + When ``power`` is provided, the value must be False, as the resulting + Tensor represents real-valued power. """ __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] @@ -66,7 +72,8 @@ def __init__(self, wkwargs: Optional[dict] = None, center: bool = True, pad_mode: str = "reflect", - onesided: bool = True) -> None: + onesided: bool = True, + return_complex: bool = False) -> None: super(Spectrogram, self).__init__() self.n_fft = n_fft # number of FFT bins. 
the returned STFT result will have n_fft // 2 + 1 @@ -81,6 +88,7 @@ def __init__(self, self.center = center self.pad_mode = pad_mode self.onesided = onesided + self.return_complex = return_complex def forward(self, waveform: Tensor) -> Tensor: r""" @@ -103,7 +111,8 @@ def forward(self, waveform: Tensor) -> Tensor: self.normalized, self.center, self.pad_mode, - self.onesided + self.onesided, + self.return_complex, ) @@ -456,7 +465,8 @@ def __init__(self, pad_mode: str = "reflect", onesided: bool = True, norm: Optional[str] = None, - mel_scale: str = "htk") -> None: + mel_scale: str = "htk", + return_complex: bool = False) -> None: super(MelSpectrogram, self).__init__() self.sample_rate = sample_rate self.n_fft = n_fft @@ -468,11 +478,20 @@ def __init__(self, self.n_mels = n_mels # number of mel frequency bins self.f_max = f_max self.f_min = f_min - self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length, - hop_length=self.hop_length, - pad=self.pad, window_fn=window_fn, power=self.power, - normalized=self.normalized, wkwargs=wkwargs, - center=center, pad_mode=pad_mode, onesided=onesided) + self.spectrogram = Spectrogram( + n_fft=self.n_fft, + win_length=self.win_length, + hop_length=self.hop_length, + pad=self.pad, + window_fn=window_fn, + power=self.power, + normalized=self.normalized, + wkwargs=wkwargs, + center=center, + pad_mode=pad_mode, + onesided=onesided, + return_complex=return_complex, + ) self.mel_scale = MelScale( self.n_mels, self.sample_rate, From 4f71f539b4d39ad998ffe7ec01be0f7c83723cf9 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 1 Apr 2021 19:01:38 +0000 Subject: [PATCH 2/5] Add `return_complex` to autograd test --- test/torchaudio_unittest/transforms/autograd_test_impl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 091aeec6ab..e86f4c282d 100644 --- 
a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -29,6 +29,10 @@ def assert_grad( assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol) @parameterized.expand([ + ({'pad': 0, 'normalized': False, 'power': None, 'return_complex': True}, ), + ({'pad': 3, 'normalized': False, 'power': None, 'return_complex': True}, ), + ({'pad': 0, 'normalized': True, 'power': None, 'return_complex': True}, ), + ({'pad': 3, 'normalized': True, 'power': None, 'return_complex': True}, ), ({'pad': 0, 'normalized': False, 'power': None}, ), ({'pad': 3, 'normalized': False, 'power': None}, ), ({'pad': 0, 'normalized': True, 'power': None}, ), From d3a36ccd8a5702b6f52a276bc5faad0803aee777 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 1 Apr 2021 19:23:23 +0000 Subject: [PATCH 3/5] Remove return_complex from melspectrogram --- torchaudio/transforms.py | 22 ++++++---------------- 1 file changed, 6 insertions(+), 16 deletions(-) diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index bfd09ea0ce..20fd1095aa 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -465,8 +465,7 @@ def __init__(self, pad_mode: str = "reflect", onesided: bool = True, norm: Optional[str] = None, - mel_scale: str = "htk", - return_complex: bool = False) -> None: + mel_scale: str = "htk") -> None: super(MelSpectrogram, self).__init__() self.sample_rate = sample_rate self.n_fft = n_fft @@ -478,20 +477,11 @@ def __init__(self, self.n_mels = n_mels # number of mel frequency bins self.f_max = f_max self.f_min = f_min - self.spectrogram = Spectrogram( - n_fft=self.n_fft, - win_length=self.win_length, - hop_length=self.hop_length, - pad=self.pad, - window_fn=window_fn, - power=self.power, - normalized=self.normalized, - wkwargs=wkwargs, - center=center, - pad_mode=pad_mode, - onesided=onesided, - return_complex=return_complex, - ) + self.spectrogram = 
Spectrogram(n_fft=self.n_fft, win_length=self.win_length, + hop_length=self.hop_length, + pad=self.pad, window_fn=window_fn, power=self.power, + normalized=self.normalized, wkwargs=wkwargs, + center=center, pad_mode=pad_mode, onesided=onesided) self.mel_scale = MelScale( self.n_mels, self.sample_rate, From f8330973468ce070c729b8cbcde2f7c2eca12593 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 1 Apr 2021 19:25:45 +0000 Subject: [PATCH 4/5] Add TS test for complex Spectrogram --- .../transforms/torchscript_consistency_impl.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index b10f8068b9..40e9a287dd 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -25,6 +25,10 @@ def test_Spectrogram(self): tensor = torch.rand((1, 1000)) self._assert_consistency(T.Spectrogram(), tensor) + def test_Spectrogram_return_complex(self): + tensor = torch.rand((1, 1000)) + self._assert_consistency(T.Spectrogram(power=None, return_complex=True), tensor) + @skipIfRocm def test_GriffinLim(self): tensor = torch.rand((1, 201, 6)) From 529de055c65f8ed4f2498f10d1cae139d1ee0a72 Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 1 Apr 2021 19:49:05 +0000 Subject: [PATCH 5/5] Add librosa compatibility test --- .../transforms/librosa_compatibility_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test.py index c954100f35..ff2a492a52 100644 --- a/test/torchaudio_unittest/transforms/librosa_compatibility_test.py +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test.py @@ -45,6 +45,20 @@ def test_spectrogram(self, 
n_fft, hop_length, power): out_torch = spect_transform(sound).squeeze().cpu() self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5) + def test_spectrogram_complex(self): + n_fft = 400 + hop_length = 200 + sample_rate = 16000 + sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate) + sound_librosa = sound.cpu().numpy().squeeze() + spect_transform = torchaudio.transforms.Spectrogram( + n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True) + out_librosa, _ = librosa.core.spectrum._spectrogram( + y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=1) + + out_torch = spect_transform(sound).squeeze() + self.assertEqual(out_torch.abs(), torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5) + @parameterized.expand([ param(norm=norm, mel_scale=mel_scale, **p.kwargs) for p in [