Skip to content

Commit c0be0ca

Browse files
committed
Update spectrogram to use complex
1 parent 47d97e3 commit c0be0ca

File tree

2 files changed

+38
-9
lines changed

2 files changed

+38
-9
lines changed

torchaudio/functional/functional.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,8 @@ def spectrogram(
4242
normalized: bool,
4343
center: bool = True,
4444
pad_mode: str = "reflect",
45-
onesided: bool = True
45+
onesided: bool = True,
46+
return_complex: bool = False,
4647
) -> Tensor:
4748
r"""Create a spectrogram or a batch of spectrograms from a raw audio signal.
4849
The spectrogram can be either magnitude-only or complex.
@@ -65,12 +66,22 @@ def spectrogram(
6566
:attr:`center` is ``True``. Default: ``"reflect"``
6667
onesided (bool, optional): controls whether to return half of results to
6768
avoid redundancy. Default: ``True``
69+
return_complex (bool, optional):
70+
If ``return_complex = True``, this function returns the resulting Tensor in
71+
complex dtype; otherwise it returns the resulting Tensor in real dtype with an extra
72+
dimension for real and imaginary parts (see ``torch.view_as_real``).
73+
When ``power`` is provided, the value must be False, as the resulting
74+
Tensor represents real-valued power.
6875
6976
Returns:
7077
Tensor: Dimension (..., freq, time), freq is
7178
``n_fft // 2 + 1`` and ``n_fft`` is the number of
7279
Fourier bins, and time is the number of window hops (n_frame).
7380
"""
81+
if power is not None and return_complex:
82+
raise ValueError(
83+
'When `power` is provided, the return value is real-valued. '
84+
'Therefore, `return_complex` must be False.')
7485

7586
if pad > 0:
7687
# TODO add "with torch.no_grad():" back when JIT supports it
@@ -103,7 +114,9 @@ def spectrogram(
103114
if power == 1.0:
104115
return spec_f.abs()
105116
return spec_f.abs().pow(power)
106-
return torch.view_as_real(spec_f)
117+
if not return_complex:
118+
return torch.view_as_real(spec_f)
119+
return spec_f
107120

108121

109122
def griffinlim(

torchaudio/transforms.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ class Spectrogram(torch.nn.Module):
5454
:attr:`center` is ``True``. Default: ``"reflect"``
5555
onesided (bool, optional): controls whether to return half of results to
5656
avoid redundancy. Default: ``True``
57+
return_complex (bool, optional):
58+
If ``return_complex = True``, this transform returns the resulting Tensor in
59+
complex dtype; otherwise it returns the resulting Tensor in real dtype with an extra
60+
dimension for real and imaginary parts (see ``torch.view_as_real``).
61+
When ``power`` is provided, the value must be False, as the resulting
62+
Tensor represents real-valued power.
5763
"""
5864
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
5965

@@ -68,7 +74,8 @@ def __init__(self,
6874
wkwargs: Optional[dict] = None,
6975
center: bool = True,
7076
pad_mode: str = "reflect",
71-
onesided: bool = True) -> None:
77+
onesided: bool = True,
78+
return_complex: bool = False) -> None:
7279
super(Spectrogram, self).__init__()
7380
self.n_fft = n_fft
7481
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
@@ -83,6 +90,7 @@ def __init__(self,
8390
self.center = center
8491
self.pad_mode = pad_mode
8592
self.onesided = onesided
93+
self.return_complex = return_complex
8694

8795
def forward(self, waveform: Tensor) -> Tensor:
8896
r"""
@@ -105,7 +113,8 @@ def forward(self, waveform: Tensor) -> Tensor:
105113
self.normalized,
106114
self.center,
107115
self.pad_mode,
108-
self.onesided
116+
self.onesided,
117+
self.return_complex,
109118
)
110119

111120

@@ -430,7 +439,8 @@ def __init__(self,
430439
window_fn: Callable[..., Tensor] = torch.hann_window,
431440
power: Optional[float] = 2.,
432441
normalized: bool = False,
433-
wkwargs: Optional[dict] = None) -> None:
442+
wkwargs: Optional[dict] = None,
443+
return_complex: bool = False) -> None:
434444
super(MelSpectrogram, self).__init__()
435445
self.sample_rate = sample_rate
436446
self.n_fft = n_fft
@@ -442,10 +452,16 @@ def __init__(self,
442452
self.n_mels = n_mels # number of mel frequency bins
443453
self.f_max = f_max
444454
self.f_min = f_min
445-
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
446-
hop_length=self.hop_length,
447-
pad=self.pad, window_fn=window_fn, power=self.power,
448-
normalized=self.normalized, wkwargs=wkwargs)
455+
self.spectrogram = Spectrogram(
456+
n_fft=self.n_fft,
457+
win_length=self.win_length,
458+
hop_length=self.hop_length,
459+
pad=self.pad,
460+
window_fn=window_fn,
461+
power=self.power,
462+
normalized=self.normalized,
463+
wkwargs=wkwargs,
464+
return_complex=return_complex,)
449465
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
450466

451467
def forward(self, waveform: Tensor) -> Tensor:

0 commit comments

Comments
 (0)