@@ -53,6 +53,12 @@ class Spectrogram(torch.nn.Module):
5353            :attr:`center` is ``True``. Default: ``"reflect"`` 
5454        onesided (bool, optional): controls whether to return half of results to 
5555            avoid redundancy Default: ``True`` 
56+         return_complex (bool, optional): 
57+             ``return_complex = True``, this function returns the resulting Tensor in 
58+             complex dtype, otherwise it returns the resulting Tensor in real dtype with extra 
59+             dimension for real and imaginary parts. (see ``torch.view_as_real``). 
60+             When ``power`` is provided, the value must be False, as the resulting 
61+             Tensor represents real-valued power. 
5662    """ 
5763    __constants__  =  ['n_fft' , 'win_length' , 'hop_length' , 'pad' , 'power' , 'normalized' ]
5864
@@ -67,7 +73,8 @@ def __init__(self,
6773                 wkwargs : Optional [dict ] =  None ,
6874                 center : bool  =  True ,
6975                 pad_mode : str  =  "reflect" ,
70-                  onesided : bool  =  True ) ->  None :
76+                  onesided : bool  =  True ,
77+                  return_complex : bool  =  False ) ->  None :
7178        super (Spectrogram , self ).__init__ ()
7279        self .n_fft  =  n_fft 
7380        # number of FFT bins. the returned STFT result will have n_fft // 2 + 1 
@@ -82,6 +89,7 @@ def __init__(self,
8289        self .center  =  center 
8390        self .pad_mode  =  pad_mode 
8491        self .onesided  =  onesided 
92+         self .return_complex  =  return_complex 
8593
8694    def  forward (self , waveform : Tensor ) ->  Tensor :
8795        r""" 
@@ -104,7 +112,8 @@ def forward(self, waveform: Tensor) -> Tensor:
104112            self .normalized ,
105113            self .center ,
106114            self .pad_mode ,
107-             self .onesided 
115+             self .onesided ,
116+             self .return_complex ,
108117        )
109118
110119
@@ -457,7 +466,8 @@ def __init__(self,
457466                 pad_mode : str  =  "reflect" ,
458467                 onesided : bool  =  True ,
459468                 norm : Optional [str ] =  None ,
460-                  mel_scale : str  =  "htk" ) ->  None :
469+                  mel_scale : str  =  "htk" ,
470+                  return_complex : bool  =  False ) ->  None :
461471        super (MelSpectrogram , self ).__init__ ()
462472        self .sample_rate  =  sample_rate 
463473        self .n_fft  =  n_fft 
@@ -469,11 +479,20 @@ def __init__(self,
469479        self .n_mels  =  n_mels   # number of mel frequency bins 
470480        self .f_max  =  f_max 
471481        self .f_min  =  f_min 
472-         self .spectrogram  =  Spectrogram (n_fft = self .n_fft , win_length = self .win_length ,
473-                                        hop_length = self .hop_length ,
474-                                        pad = self .pad , window_fn = window_fn , power = self .power ,
475-                                        normalized = self .normalized , wkwargs = wkwargs ,
476-                                        center = center , pad_mode = pad_mode , onesided = onesided )
482+         self .spectrogram  =  Spectrogram (
483+             n_fft = self .n_fft ,
484+             win_length = self .win_length ,
485+             hop_length = self .hop_length ,
486+             pad = self .pad ,
487+             window_fn = window_fn ,
488+             power = self .power ,
489+             normalized = self .normalized ,
490+             wkwargs = wkwargs ,
491+             center = center ,
492+             pad_mode = pad_mode ,
493+             onesided = onesided ,
494+             return_complex = return_complex ,
495+         )
477496        self .mel_scale  =  MelScale (
478497            self .n_mels ,
479498            self .sample_rate ,
0 commit comments