Skip to content

Commit 21e9316

Browse files
committed
Add return_complex to MelSpectrogram
1 parent 5e2a101 commit 21e9316

File tree

2 files changed

+29
-7
lines changed

2 files changed

+29
-7
lines changed

torchaudio/functional/functional.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,12 @@ def spectrogram(
6666
:attr:`center` is ``True``. Default: ``"reflect"``
6767
onesided (bool, optional): controls whether to return half of results to
6868
avoid redundancy. Default: ``True``
69+
return_complex (bool, optional):
70+
``return_complex = True``, this function returns the resulting Tensor in
71+
complex dtype, otherwise it returns the resulting Tensor in real dtype with extra
72+
dimension for real and imaginary parts. (see ``torch.view_as_real``).
73+
When ``power`` is provided, the value must be False, as the resulting
74+
Tensor represents real-valued power.
6975
7076
Returns:
7177
Tensor: Dimension (..., freq, time), freq is

torchaudio/transforms.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,12 @@ class Spectrogram(torch.nn.Module):
5454
:attr:`center` is ``True``. Default: ``"reflect"``
5555
onesided (bool, optional): controls whether to return half of results to
5656
avoid redundancy Default: ``True``
57+
return_complex (bool, optional):
58+
``return_complex = True``, this function returns the resulting Tensor in
59+
complex dtype, otherwise it returns the resulting Tensor in real dtype with extra
60+
dimension for real and imaginary parts. (see ``torch.view_as_real``).
61+
When ``power`` is provided, the value must be False, as the resulting
62+
Tensor represents real-valued power.
5763
"""
5864
__constants__ = ['n_fft', 'win_length', 'hop_length', 'pad', 'power', 'normalized']
5965

@@ -68,7 +74,8 @@ def __init__(self,
6874
wkwargs: Optional[dict] = None,
6975
center: bool = True,
7076
pad_mode: str = "reflect",
71-
onesided: bool = True) -> None:
77+
onesided: bool = True,
78+
return_complex: bool = False) -> None:
7279
super(Spectrogram, self).__init__()
7380
self.n_fft = n_fft
7481
# number of FFT bins. the returned STFT result will have n_fft // 2 + 1
@@ -83,6 +90,7 @@ def __init__(self,
8390
self.center = center
8491
self.pad_mode = pad_mode
8592
self.onesided = onesided
93+
self.return_complex = return_complex
8694

8795
def forward(self, waveform: Tensor) -> Tensor:
8896
r"""
@@ -105,7 +113,8 @@ def forward(self, waveform: Tensor) -> Tensor:
105113
self.normalized,
106114
self.center,
107115
self.pad_mode,
108-
self.onesided
116+
self.onesided,
117+
self.return_compex,
109118
)
110119

111120

@@ -430,7 +439,8 @@ def __init__(self,
430439
window_fn: Callable[..., Tensor] = torch.hann_window,
431440
power: Optional[float] = 2.,
432441
normalized: bool = False,
433-
wkwargs: Optional[dict] = None) -> None:
442+
wkwargs: Optional[dict] = None,
443+
return_complex: bool = False) -> None:
434444
super(MelSpectrogram, self).__init__()
435445
self.sample_rate = sample_rate
436446
self.n_fft = n_fft
@@ -442,10 +452,16 @@ def __init__(self,
442452
self.n_mels = n_mels # number of mel frequency bins
443453
self.f_max = f_max
444454
self.f_min = f_min
445-
self.spectrogram = Spectrogram(n_fft=self.n_fft, win_length=self.win_length,
446-
hop_length=self.hop_length,
447-
pad=self.pad, window_fn=window_fn, power=self.power,
448-
normalized=self.normalized, wkwargs=wkwargs)
455+
self.spectrogram = Spectrogram(
456+
n_fft=self.n_fft,
457+
win_length=self.win_length,
458+
hop_length=self.hop_length,
459+
pad=self.pad,
460+
window_fn=window_fn,
461+
power=self.power,
462+
normalized=self.normalized,
463+
wkwargs=wkwargs,
464+
return_complex=return_complex,)
449465
self.mel_scale = MelScale(self.n_mels, self.sample_rate, self.f_min, self.f_max, self.n_fft // 2 + 1)
450466

451467
def forward(self, waveform: Tensor) -> Tensor:

0 commit comments

Comments
 (0)