@@ -54,6 +54,12 @@ class Spectrogram(torch.nn.Module):
5454 :attr:`center` is ``True``. Default: ``"reflect"``
5555 onesided (bool, optional): controls whether to return half of results to
5656 avoid redundancy Default: ``True``
57+ return_complex (bool, optional):
58+ If ``True``, this function returns the resulting Tensor in complex dtype;
59+ otherwise it returns the resulting Tensor in real dtype with an extra
60+ dimension for the real and imaginary parts (see ``torch.view_as_real``).
61+ When ``power`` is provided, the value must be ``False``, as the resulting
62+ Tensor represents real-valued power. Default: ``False``
5763 """
5864 __constants__ = ['n_fft' , 'win_length' , 'hop_length' , 'pad' , 'power' , 'normalized' ]
5965
@@ -68,7 +74,8 @@ def __init__(self,
6874 wkwargs : Optional [dict ] = None ,
6975 center : bool = True ,
7076 pad_mode : str = "reflect" ,
71- onesided : bool = True ) -> None :
77+ onesided : bool = True ,
78+ return_complex : bool = False ) -> None :
7279 super (Spectrogram , self ).__init__ ()
7380 self .n_fft = n_fft
7481 # number of FFT bins. the returned STFT result will have n_fft // 2 + 1
@@ -83,6 +90,7 @@ def __init__(self,
8390 self .center = center
8491 self .pad_mode = pad_mode
8592 self .onesided = onesided
93+ self .return_complex = return_complex
8694
8795 def forward (self , waveform : Tensor ) -> Tensor :
8896 r"""
@@ -105,7 +113,8 @@ def forward(self, waveform: Tensor) -> Tensor:
105113 self .normalized ,
106114 self .center ,
107115 self .pad_mode ,
108- self .onesided
116+ self .onesided ,
117+ self .return_complex ,
109118 )
110119
111120
@@ -430,7 +439,8 @@ def __init__(self,
430439 window_fn : Callable [..., Tensor ] = torch .hann_window ,
431440 power : Optional [float ] = 2. ,
432441 normalized : bool = False ,
433- wkwargs : Optional [dict ] = None ) -> None :
442+ wkwargs : Optional [dict ] = None ,
443+ return_complex : bool = False ) -> None :
434444 super (MelSpectrogram , self ).__init__ ()
435445 self .sample_rate = sample_rate
436446 self .n_fft = n_fft
@@ -442,10 +452,16 @@ def __init__(self,
442452 self .n_mels = n_mels # number of mel frequency bins
443453 self .f_max = f_max
444454 self .f_min = f_min
445- self .spectrogram = Spectrogram (n_fft = self .n_fft , win_length = self .win_length ,
446- hop_length = self .hop_length ,
447- pad = self .pad , window_fn = window_fn , power = self .power ,
448- normalized = self .normalized , wkwargs = wkwargs )
455+ self .spectrogram = Spectrogram (
456+ n_fft = self .n_fft ,
457+ win_length = self .win_length ,
458+ hop_length = self .hop_length ,
459+ pad = self .pad ,
460+ window_fn = window_fn ,
461+ power = self .power ,
462+ normalized = self .normalized ,
463+ wkwargs = wkwargs ,
464+ return_complex = return_complex ,)
449465 self .mel_scale = MelScale (self .n_mels , self .sample_rate , self .f_min , self .f_max , self .n_fft // 2 + 1 )
450466
451467 def forward (self , waveform : Tensor ) -> Tensor :
0 commit comments