diff --git a/test/torchaudio_unittest/transforms/autograd_test_impl.py b/test/torchaudio_unittest/transforms/autograd_test_impl.py index 091aeec6ab..e86f4c282d 100644 --- a/test/torchaudio_unittest/transforms/autograd_test_impl.py +++ b/test/torchaudio_unittest/transforms/autograd_test_impl.py @@ -29,6 +29,10 @@ def assert_grad( assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol) @parameterized.expand([ + ({'pad': 0, 'normalized': False, 'power': None, 'return_complex': True}, ), + ({'pad': 3, 'normalized': False, 'power': None, 'return_complex': True}, ), + ({'pad': 0, 'normalized': True, 'power': None, 'return_complex': True}, ), + ({'pad': 3, 'normalized': True, 'power': None, 'return_complex': True}, ), ({'pad': 0, 'normalized': False, 'power': None}, ), ({'pad': 3, 'normalized': False, 'power': None}, ), ({'pad': 0, 'normalized': True, 'power': None}, ), diff --git a/test/torchaudio_unittest/transforms/librosa_compatibility_test.py b/test/torchaudio_unittest/transforms/librosa_compatibility_test.py index c954100f35..ff2a492a52 100644 --- a/test/torchaudio_unittest/transforms/librosa_compatibility_test.py +++ b/test/torchaudio_unittest/transforms/librosa_compatibility_test.py @@ -45,6 +45,20 @@ def test_spectrogram(self, n_fft, hop_length, power): out_torch = spect_transform(sound).squeeze().cpu() self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5) + def test_spectrogram_complex(self): + n_fft = 400 + hop_length = 200 + sample_rate = 16000 + sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate) + sound_librosa = sound.cpu().numpy().squeeze() + spect_transform = torchaudio.transforms.Spectrogram( + n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True) + out_librosa, _ = librosa.core.spectrum._spectrogram( + y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=1) + + out_torch = spect_transform(sound).squeeze() + self.assertEqual(out_torch.abs(), torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5) + @parameterized.expand([ param(norm=norm, mel_scale=mel_scale, **p.kwargs) for p in [ diff --git a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py index b10f8068b9..40e9a287dd 100644 --- a/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py +++ b/test/torchaudio_unittest/transforms/torchscript_consistency_impl.py @@ -25,6 +25,10 @@ def test_Spectrogram(self): tensor = torch.rand((1, 1000)) self._assert_consistency(T.Spectrogram(), tensor) + def test_Spectrogram_return_complex(self): + tensor = torch.rand((1, 1000)) + self._assert_consistency(T.Spectrogram(power=None, return_complex=True), tensor) + @skipIfRocm def test_GriffinLim(self): tensor = torch.rand((1, 201, 6)) diff --git a/torchaudio/functional/functional.py b/torchaudio/functional/functional.py index a26d7ab7fe..be18ceb3ef 100644 --- a/torchaudio/functional/functional.py +++ b/torchaudio/functional/functional.py @@ -48,7 +48,8 @@ def spectrogram( normalized: bool, center: bool = True, pad_mode: str = "reflect", - onesided: bool = True + onesided: bool = True, + return_complex: bool = False, ) -> Tensor: r"""Create a spectrogram or a batch of spectrograms from a raw audio signal. The spectrogram can be either magnitude-only or complex. @@ -71,12 +72,22 @@ def spectrogram( :attr:`center` is ``True``. Default: ``"reflect"`` onesided (bool, optional): controls whether to return half of results to avoid redundancy. Default: ``True`` + return_complex (bool, optional): + ``return_complex = True``, this function returns the resulting Tensor in + complex dtype, otherwise it returns the resulting Tensor in real dtype with extra + dimension for real and imaginary parts. (see ``torch.view_as_real``). + When ``power`` is provided, the value must be False, as the resulting + Tensor represents real-valued power. Returns: Tensor: Dimension (..., freq, time), freq is ``n_fft // 2 + 1`` and ``n_fft`` is the number of Fourier bins, and time is the number of window hops (n_frame). """ + if power is not None and return_complex: + raise ValueError( + 'When `power` is provided, the return value is real-valued. ' + 'Therefore, `return_complex` must be False.') if pad > 0: # TODO add "with torch.no_grad():" back when JIT supports it @@ -109,7 +120,9 @@ def spectrogram( if power == 1.0: return spec_f.abs() return spec_f.abs().pow(power) - return torch.view_as_real(spec_f) + if not return_complex: + return torch.view_as_real(spec_f) + return spec_f def griffinlim( diff --git a/torchaudio/transforms.py b/torchaudio/transforms.py index e5dd5b2210..20fd1095aa 100644 --- a/torchaudio/transforms.py +++ b/torchaudio/transforms.py @@ -52,6 +52,12 @@ class Spectrogram(torch.nn.Module): :attr:`center` is ``True``. Default: ``"reflect"`` onesided (bool, optional): controls whether to return half of results to avoid redundancy Default: ``True`` + return_complex (bool, optional): + ``return_complex = True``, this function returns the resulting Tensor in + complex dtype, otherwise it returns the resulting Tensor in real dtype with extra + dimension for real and imaginary parts. (see ``torch.view_as_real``). + When ``power`` is provided, the value must be False, as the resulting + Tensor represents real-valued power. """ __constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized'] @@ -66,7 +72,8 @@ def __init__(self, wkwargs: Optional[dict] = None, center: bool = True, pad_mode: str = "reflect", - onesided: bool = True) -> None: + onesided: bool = True, + return_complex: bool = False) -> None: super(Spectrogram, self).__init__() self.n_fft = n_fft # number of FFT bins. the returned STFT result will have n_fft // 2 + 1 @@ -81,6 +88,7 @@ def __init__(self, self.center = center self.pad_mode = pad_mode self.onesided = onesided + self.return_complex = return_complex def forward(self, waveform: Tensor) -> Tensor: r""" @@ -103,7 +111,8 @@ def forward(self, waveform: Tensor) -> Tensor: self.normalized, self.center, self.pad_mode, - self.onesided + self.onesided, + self.return_complex, )