Skip to content

Commit 2212bfa

Browse files
committed
Update spectrogram to use complex
1 parent ea85794 commit 2212bfa

File tree

2 files changed

+42
-10
lines changed

2 files changed

+42
-10
lines changed

torchaudio/functional/functional.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def spectrogram(
4747
normalized: bool,
4848
center: bool = True,
4949
pad_mode: str = "reflect",
50-
onesided: bool = True
50+
onesided: bool = True,
51+
return_complex: bool = False,
5152
) -> Tensor:
5253
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
5354
The spectrogram can be either magnitude-only or complex.
@@ -70,12 +71,22 @@ def spectrogram(
7071
:attr:`center` is ``True``. Default: ``"reflect"``
7172
onesided (bool, optional): controls whether to return half of results to
7273
avoid redundancy. Default: ``True``
74+
return_complex (bool, optional):
75+
If ``return_complex = True``, this function returns the resulting Tensor in
76+
complex dtype, otherwise it returns the resulting Tensor in real dtype with extra
77+
dimension for real and imaginary parts (see ``torch.view_as_real``).
78+
When ``power`` is provided, ``return_complex`` must be ``False``, as the resulting
79+
Tensor represents real-valued power.
7380
7481
Returns:
7582
Tensor: Dimension (..., freq, time), freq is
7683
``n_fft // 2 + 1`` and ``n_fft`` is the number of
7784
Fourier bins, and time is the number of window hops (n_frame).
7885
"""
86+
if power is not None and return_complex:
87+
raise ValueError(
88+
'When `power` is provided, the return value is real-valued. '
89+
'Therefore, `return_complex` must be False.')
7990

8091
if pad > 0:
8192
# TODO add "with torch.no_grad():" back when JIT supports it
@@ -108,7 +119,9 @@ def spectrogram(
108119
if power == 1.0:
109120
return spec_f.abs()
110121
return spec_f.abs().pow(power)
111-
return torch.view_as_real(spec_f)
122+
if not return_complex:
123+
return torch.view_as_real(spec_f)
124+
return spec_f
112125

113126

114127
def griffinlim(

torchaudio/transforms.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ class Spectrogram(torch.nn.Module):
5353
:attr:`center` is ``True``. Default: ``"reflect"``
5454
onesided (bool, optional): controls whether to return half of results to
5555
avoid redundancy. Default: ``True``
56+
return_complex (bool, optional):
57+
If ``return_complex = True``, this function returns the resulting Tensor in
58+
complex dtype, otherwise it returns the resulting Tensor in real dtype with extra
59+
dimension for real and imaginary parts (see ``torch.view_as_real``).
60+
When ``power`` is provided, ``return_complex`` must be ``False``, as the resulting
61+
Tensor represents real-valued power.
5662
"""
5763
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
5864

@@ -67,7 +73,8 @@ def __init__(self,
6773
wkwargs: Optional[dict] = None,
6874
center: bool = True,
6975
pad_mode: str = "reflect",
70-
onesided: bool = True) -> None:
76+
onesided: bool = True,
77+
return_complex: bool = False) -> None:
7178
super(Spectrogram, self).__init__()
7279
self.n_fft = n_fft
7380
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
@@ -82,6 +89,7 @@ def __init__(self,
8289
self.center = center
8390
self.pad_mode = pad_mode
8491
self.onesided = onesided
92+
self.return_complex = return_complex
8593

8694
def forward(self, waveform: Tensor) -> Tensor:
8795
r"""
@@ -104,7 +112,8 @@ def forward(self, waveform: Tensor) -> Tensor:
104112
self.normalized,
105113
self.center,
106114
self.pad_mode,
107-
self.onesided
115+
self.onesided,
116+
self.return_complex,
108117
)
109118

110119

@@ -457,7 +466,8 @@ def __init__(self,
457466
pad_mode: str = "reflect",
458467
onesided: bool = True,
459468
norm: Optional[str] = None,
460-
mel_scale: str = "htk") -> None:
469+
mel_scale: str = "htk",
470+
return_complex: bool = False) -> None:
461471
super(MelSpectrogram, self).__init__()
462472
self.sample_rate = sample_rate
463473
self.n_fft = n_fft
@@ -469,11 +479,20 @@ def __init__(self,
469479
self.n_mels = n_mels # number of mel frequency bins
470480
self.f_max = f_max
471481
self.f_min = f_min
472-
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
473-
hop_length=self.hop_length,
474-
pad=self.pad, window_fn=window_fn, power=self.power,
475-
normalized=self.normalized, wkwargs=wkwargs,
476-
center=center, pad_mode=pad_mode, onesided=onesided)
482+
self.spectrogram = Spectrogram(
483+
n_fft=self.n_fft,
484+
win_length=self.win_length,
485+
hop_length=self.hop_length,
486+
pad=self.pad,
487+
window_fn=window_fn,
488+
power=self.power,
489+
normalized=self.normalized,
490+
wkwargs=wkwargs,
491+
center=center,
492+
pad_mode=pad_mode,
493+
onesided=onesided,
494+
return_complex=return_complex,
495+
)
477496
self.mel_scale = MelScale(
478497
self.n_mels,
479498
self.sample_rate,

0 commit comments

Comments
 (0)