22
33import math
44import torch
5- import torchaudio
65from torch import Tensor
76
7+ import torchaudio
8+ import torchaudio ._internal .fft
9+
810__all__ = [
911 'get_mel_banks' ,
1012 'inverse_mel_scale' ,
@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor,
289291 snip_edges , raw_energy , energy_floor , dither , remove_dc_offset , preemphasis_coefficient )
290292
291293 # size (m, padded_window_size // 2 + 1, 2)
292- fft = torch . rfft (strided_input , 1 , normalized = False , onesided = True )
294+ fft = torchaudio . _internal . fft . rfft (strided_input )
293295
294296 # Convert the FFT into a power spectrum
295- power_spectrum = torch .max (fft .pow ( 2 ). sum ( 2 ), epsilon ).log () # size (m, padded_window_size // 2 + 1)
297+ power_spectrum = torch .max (fft .abs (). pow ( 2. ), epsilon ).log () # size (m, padded_window_size // 2 + 1)
296298 power_spectrum [:, 0 ] = signal_log_energy
297299
298300 power_spectrum = _subtract_column_mean (power_spectrum , subtract_mean )
@@ -570,12 +572,10 @@ def fbank(waveform: Tensor,
570572 waveform , padded_window_size , window_size , window_shift , window_type , blackman_coeff ,
571573 snip_edges , raw_energy , energy_floor , dither , remove_dc_offset , preemphasis_coefficient )
572574
573- # size (m, padded_window_size // 2 + 1, 2)
574- fft = torch .rfft (strided_input , 1 , normalized = False , onesided = True )
575-
576- power_spectrum = fft .pow (2 ).sum (2 ) # size (m, padded_window_size // 2 + 1)
577- if not use_power :
578- power_spectrum = power_spectrum .pow (0.5 )
575+ # size (m, padded_window_size // 2 + 1)
576+ spectrum = torchaudio ._internal .fft .rfft (strided_input ).abs ()
577+ if use_power :
578+ spectrum = spectrum .pow (2. )
579579
580580 # size (num_mel_bins, padded_window_size // 2)
581581 mel_energies , _ = get_mel_banks (num_mel_bins , padded_window_size , sample_frequency ,
@@ -586,7 +586,7 @@ def fbank(waveform: Tensor,
586586 mel_energies = torch .nn .functional .pad (mel_energies , (0 , 1 ), mode = 'constant' , value = 0 )
587587
588588 # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
589- mel_energies = torch .mm (power_spectrum , mel_energies .T )
589+ mel_energies = torch .mm (spectrum , mel_energies .T )
590590 if use_log_fbank :
591591 # avoid log of zero (which should be prevented anyway by dithering)
592592 mel_energies = torch .max (mel_energies , _get_epsilon (device , dtype )).log ()
0 commit comments