Skip to content

Commit 57c7f7b

Browse files
committed
[BC-Breaking] Default to native complex type when returning raw spectrogram
Part of #1337 . - This code changes the return type of spectrogram to be native complex dtype, when (and only when) returning raw (complex-valued) spectrogram. - Change `return_complex=False` to `return_complex=True` in spectrogram ops. - `return_complex` is only effective when `power` is `None`. It is ignored for cases where `power` is not `None`. Because the returned Tensor is power spectrogram, which is real-valued Tensors.
1 parent 6882342 commit 57c7f7b

File tree

3 files changed

+18
-4
lines changed

3 files changed

+18
-4
lines changed

test/torchaudio_unittest/functional/torchscript_consistency_impl.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,20 +51,34 @@ def _assert_consistency_complex(self, func, tensor, test_pseudo_complex=False):
5151

5252
self.assertEqual(ts_output, output)
5353

54-
def test_spectrogram(self):
54+
def test_spectrogram_complex(self):
5555
def func(tensor):
5656
n_fft = 400
5757
ws = 400
5858
hop = 200
5959
pad = 0
6060
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
61-
power = 2.
61+
power = None
6262
normalize = False
6363
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)
6464

6565
tensor = common_utils.get_whitenoise()
6666
self._assert_consistency(func, tensor)
6767

68+
def test_spectrogram_real(self):
69+
def func(tensor):
70+
n_fft = 400
71+
ws = 400
72+
hop = 200
73+
pad = 0
74+
window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
75+
power = 2.
76+
normalize = False
77+
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize, return_complex=False)
78+
79+
tensor = common_utils.get_whitenoise()
80+
self._assert_consistency(func, tensor)
81+
6882
@skipIfRocm
6983
def test_griffinlim(self):
7084
def func(tensor):

torchaudio/functional/functional.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def spectrogram(
4949
center: bool = True,
5050
pad_mode: str = "reflect",
5151
onesided: bool = True,
52-
return_complex: bool = False,
52+
return_complex: bool = True,
5353
) -> Tensor:
5454
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
5555
The spectrogram can be either magnitude-only or complex.

torchaudio/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def __init__(self,
8080
center: bool = True,
8181
pad_mode: str = "reflect",
8282
onesided: bool = True,
83-
return_complex: bool = False) -> None:
83+
return_complex: bool = True) -> None:
8484
super(Spectrogram, self).__init__()
8585
self.n_fft = n_fft
8686
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1

0 commit comments

Comments
 (0)