4 changes: 4 additions & 0 deletions test/torchaudio_unittest/transforms/autograd_test_impl.py
@@ -29,6 +29,10 @@ def assert_grad(
assert gradgradcheck(transform, inputs_, nondet_tol=nondet_tol)

@parameterized.expand([
({'pad': 0, 'normalized': False, 'power': None, 'return_complex': True}, ),
({'pad': 3, 'normalized': False, 'power': None, 'return_complex': True}, ),
({'pad': 0, 'normalized': True, 'power': None, 'return_complex': True}, ),
({'pad': 3, 'normalized': True, 'power': None, 'return_complex': True}, ),
({'pad': 0, 'normalized': False, 'power': None}, ),
({'pad': 3, 'normalized': False, 'power': None}, ),
({'pad': 0, 'normalized': True, 'power': None}, ),
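For context, the parameter sets added above exercise autograd through a complex-valued Spectrogram. A minimal sketch of the kind of check `assert_grad` performs (the double-precision cast and the input shape are assumptions, not part of the diff):

```python
import torch
import torchaudio.transforms as T

# gradcheck requires float64; Module.double() also converts the window buffer.
transform = T.Spectrogram(pad=3, normalized=True, power=None, return_complex=True).double()
waveform = torch.rand(1, 1000, dtype=torch.float64, requires_grad=True)

# gradcheck handles complex outputs, so the complex spectrogram is checked directly.
assert torch.autograd.gradcheck(transform, (waveform,))
```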
14 changes: 14 additions & 0 deletions test/torchaudio_unittest/transforms/librosa_compatibility_test.py
@@ -45,6 +45,20 @@ def test_spectrogram(self, n_fft, hop_length, power):
out_torch = spect_transform(sound).squeeze().cpu()
self.assertEqual(out_torch, torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

def test_spectrogram_complex(self):
n_fft = 400
hop_length = 200
sample_rate = 16000
sound = common_utils.get_sinusoid(n_channels=1, sample_rate=sample_rate)
sound_librosa = sound.cpu().numpy().squeeze()
spect_transform = torchaudio.transforms.Spectrogram(
n_fft=n_fft, hop_length=hop_length, power=None, return_complex=True)
out_librosa, _ = librosa.core.spectrum._spectrogram(
y=sound_librosa, n_fft=n_fft, hop_length=hop_length, power=1)

out_torch = spect_transform(sound).squeeze()
self.assertEqual(out_torch.abs(), torch.from_numpy(out_librosa), atol=1e-5, rtol=1e-5)

@parameterized.expand([
param(norm=norm, mel_scale=mel_scale, **p.kwargs)
for p in [
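A note on `test_spectrogram_complex` above: librosa's `_spectrogram(power=1)` returns the magnitude spectrum, which is why the test compares `out_torch.abs()`. A standalone sketch of the same relationship, with a random waveform standing in for `get_sinusoid`:

```python
import torch
import torchaudio.transforms as T

waveform = torch.rand(1, 16000)  # stand-in for the test's sinusoid
complex_spec = T.Spectrogram(n_fft=400, hop_length=200,
                             power=None, return_complex=True)(waveform)
magnitude_spec = T.Spectrogram(n_fft=400, hop_length=200, power=1.0)(waveform)

# The magnitude of the complex STFT equals the power=1 spectrogram.
assert torch.allclose(complex_spec.abs(), magnitude_spec)
```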
@@ -25,6 +25,10 @@ def test_Spectrogram(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(), tensor)

def test_Spectrogram_return_complex(self):
tensor = torch.rand((1, 1000))
self._assert_consistency(T.Spectrogram(power=None, return_complex=True), tensor)

@skipIfRocm
def test_GriffinLim(self):
tensor = torch.rand((1, 201, 6))
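Presumably `_assert_consistency` scripts the module and compares eager and TorchScript outputs; a hedged sketch of the equivalent check for the new `test_Spectrogram_return_complex` case:

```python
import torch
import torchaudio.transforms as T

module = T.Spectrogram(power=None, return_complex=True)
scripted = torch.jit.script(module)

tensor = torch.rand(1, 1000)
# Eager and scripted paths should agree, including for complex outputs.
assert torch.allclose(module(tensor), scripted(tensor))
```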
17 changes: 15 additions & 2 deletions torchaudio/functional/functional.py
@@ -48,7 +48,8 @@ def spectrogram(
normalized: bool,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True
onesided: bool = True,
return_complex: bool = False,
) -> Tensor:
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
The spectrogram can be either magnitude-only or complex.
@@ -71,12 +72,22 @@
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
return_complex (bool, optional):
    When ``return_complex = True``, this function returns the resulting Tensor in
    complex dtype, otherwise it returns the resulting Tensor in real dtype with an extra
    dimension for the real and imaginary parts (see ``torch.view_as_real``).
    When ``power`` is provided, the value must be ``False``, as the resulting
    Tensor represents real-valued power.

Returns:
Tensor: Dimension (..., freq, time), freq is
``n_fft // 2 + 1`` and ``n_fft`` is the number of
Fourier bins, and time is the number of window hops (n_frame).
"""
if power is not None and return_complex:
raise ValueError(
'When `power` is provided, the return value is real-valued. '
'Therefore, `return_complex` must be False.')

if pad > 0:
# TODO add "with torch.no_grad():" back when JIT supports it
@@ -109,7 +120,9 @@
if power == 1.0:
return spec_f.abs()
return spec_f.abs().pow(power)
return torch.view_as_real(spec_f)
if not return_complex:
return torch.view_as_real(spec_f)
return spec_f


def griffinlim(
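To illustrate the validation and the new return branch in `torchaudio.functional.spectrogram`, a small sketch (window and parameter values are illustrative):

```python
import torch
import torchaudio.functional as F

waveform = torch.rand(1, 1000)
common = dict(pad=0, window=torch.hann_window(400), n_fft=400,
              hop_length=200, win_length=400, normalized=False)

pseudo = F.spectrogram(waveform, power=None, **common)                        # real dtype, (..., freq, time, 2)
native = F.spectrogram(waveform, power=None, return_complex=True, **common)   # complex dtype, (..., freq, time)
assert torch.allclose(torch.view_as_complex(pseudo.contiguous()), native)

# Combining power with return_complex=True raises, as the result is real-valued.
try:
    F.spectrogram(waveform, power=2.0, return_complex=True, **common)
except ValueError as err:
    print(err)
```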
13 changes: 11 additions & 2 deletions torchaudio/transforms.py
@@ -52,6 +52,12 @@ class Spectrogram(torch.nn.Module):
:attr:`center` is ``True``. Default: ``"reflect"``
onesided (bool, optional): controls whether to return half of results to
avoid redundancy. Default: ``True``
return_complex (bool, optional):
    When ``return_complex = True``, this function returns the resulting Tensor in
    complex dtype, otherwise it returns the resulting Tensor in real dtype with an extra
    dimension for the real and imaginary parts (see ``torch.view_as_real``).
    When ``power`` is provided, the value must be ``False``, as the resulting
    Tensor represents real-valued power.
@anjali411 commented on Apr 2, 2021:

I'd suggest adding a warning that Spectrogram now supports native complex tensors. Currently we return pseudo complex tensors (..., 2) by default, but that will change in the future and ``return_complex`` would be set to ``True`` by default ... something along those lines
@mthrok (Contributor, Author) commented on Apr 2, 2021:

I like the idea, but in domain libraries it is agreed not to mention migration plans or future plans in docstrings (even though there are exceptions that I made when this rule was introduced).

@cpuhrsch brought this up, and this is meant to follow the same approach as PyTorch. But if PyTorch also mentions future plans in docstrings, please point us to it.

The general idea behind it is that this kind of direction should better live in the release notes or an issue.

Reply:

We try to always mention the new path. To me, the definition of deprecation has two parts:

  1. this isn't recommended anymore
  2. this is what you should do instead

You could, I guess, put those in the release notes, but why make it more difficult for your users?

@mthrok (Contributor, Author):

Oh sorry, I misread the comment by @anjali411 as suggesting adding that sentence to the docstring. Yes, nothing speaks against adding it as a warning.

@mthrok (Contributor, Author):

I will make a follow-up PR.
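For reference, the follow-up discussed here could look roughly like this hypothetical helper (`_warn_pseudo_complex_default` is a made-up name, not part of this PR):

```python
import warnings

def _warn_pseudo_complex_default(power, return_complex):
    # Hypothetical sketch of the warning suggested above.
    if power is None and not return_complex:
        warnings.warn(
            "Spectrogram now supports native complex tensors. The default "
            "pseudo complex (..., 2) output may change to native complex in "
            "a future release; pass return_complex=True to opt in now.")
```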

"""
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']

@@ -66,7 +72,8 @@ def __init__(self,
wkwargs: Optional[dict] = None,
center: bool = True,
pad_mode: str = "reflect",
onesided: bool = True) -> None:
onesided: bool = True,
return_complex: bool = False) -> None:
super(Spectrogram, self).__init__()
self.n_fft = n_fft
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
@@ -81,6 +88,7 @@ def __init__(self,
self.center = center
self.pad_mode = pad_mode
self.onesided = onesided
self.return_complex = return_complex

def forward(self, waveform: Tensor) -> Tensor:
r"""
@@ -103,7 +111,8 @@ def forward(self, waveform: Tensor) -> Tensor:
self.normalized,
self.center,
self.pad_mode,
self.onesided
self.onesided,
self.return_complex,
)


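Putting it together, a minimal usage sketch of the new flag on the transform (shapes assume a 16 kHz, one-second input):

```python
import torch
import torchaudio.transforms as T

waveform = torch.rand(1, 16000)
transform = T.Spectrogram(n_fft=400, hop_length=200, power=None, return_complex=True)
spec = transform(waveform)

print(spec.dtype)  # torch.complex64
print(spec.shape)  # torch.Size([1, 201, 81]): (channel, n_fft // 2 + 1, n_frames)
```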