22
33import  math 
44import  torch 
5- import  torchaudio 
65from  torch  import  Tensor 
76
7+ import  torchaudio 
8+ import  torchaudio ._internal .fft 
9+ 
810__all__  =  [
911    'get_mel_banks' ,
1012    'inverse_mel_scale' ,
@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor,
289291        snip_edges , raw_energy , energy_floor , dither , remove_dc_offset , preemphasis_coefficient )
290292
291293    # size (m, padded_window_size // 2 + 1) 
292-     fft  =  torch . rfft (strided_input ,  1 ,  normalized = False ,  onesided = True )
294+     fft  =  torchaudio . _internal . fft . rfft (strided_input )
293295
294296    # Convert the FFT into a power spectrum 
295-     power_spectrum  =  torch .max (fft .pow ( 2 ). sum ( 2 ), epsilon ).log ()  # size (m, padded_window_size // 2 + 1) 
297+     power_spectrum  =  torch .max (fft .abs (). pow ( 2. ), epsilon ).log ()  # size (m, padded_window_size // 2 + 1) 
296298    power_spectrum [:, 0 ] =  signal_log_energy 
297299
298300    power_spectrum  =  _subtract_column_mean (power_spectrum , subtract_mean )
@@ -570,12 +572,10 @@ def fbank(waveform: Tensor,
570572        waveform , padded_window_size , window_size , window_shift , window_type , blackman_coeff ,
571573        snip_edges , raw_energy , energy_floor , dither , remove_dc_offset , preemphasis_coefficient )
572574
573-     # size (m, padded_window_size // 2 + 1, 2) 
574-     fft  =  torch .rfft (strided_input , 1 , normalized = False , onesided = True )
575- 
576-     power_spectrum  =  fft .pow (2 ).sum (2 )  # size (m, padded_window_size // 2 + 1) 
577-     if  not  use_power :
578-         power_spectrum  =  power_spectrum .pow (0.5 )
575+     # size (m, padded_window_size // 2 + 1) 
576+     spectrum  =  torchaudio ._internal .fft .rfft (strided_input ).abs ()
577+     if  use_power :
578+         spectrum  =  spectrum .pow (2. )
579579
580580    # size (num_mel_bins, padded_window_size // 2) 
581581    mel_energies , _  =  get_mel_banks (num_mel_bins , padded_window_size , sample_frequency ,
@@ -586,7 +586,7 @@ def fbank(waveform: Tensor,
586586    mel_energies  =  torch .nn .functional .pad (mel_energies , (0 , 1 ), mode = 'constant' , value = 0 )
587587
588588    # sum with mel filterbanks over the power spectrum, size (m, num_mel_bins) 
589-     mel_energies  =  torch .mm (power_spectrum , mel_energies .T )
589+     mel_energies  =  torch .mm (spectrum , mel_energies .T )
590590    if  use_log_fbank :
591591        # avoid log of zero (which should be prevented anyway by dithering) 
592592        mel_energies  =  torch .max (mel_energies , _get_epsilon (device , dtype )).log ()
0 commit comments