From 4b33f817a7a4e7b7c5a3e87dc73428bafee00dbc Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Thu, 8 Oct 2020 23:02:30 +0000 Subject: [PATCH] Update torch.rfft to torch.fft.rfft and complex tensor --- torchaudio/_internal/fft.py | 27 +++++++++++++++++++++++++++ torchaudio/compliance/kaldi.py | 20 ++++++++++---------- torchaudio/functional.py | 12 +++++------- 3 files changed, 42 insertions(+), 17 deletions(-) create mode 100644 torchaudio/_internal/fft.py diff --git a/torchaudio/_internal/fft.py b/torchaudio/_internal/fft.py new file mode 100644 index 0000000000..45350a4972 --- /dev/null +++ b/torchaudio/_internal/fft.py @@ -0,0 +1,27 @@ +"""Compatibility module for fft-related functions + +In PyTorch 1.7, the new `torch.fft` module was introduced. + +To use this new module, one has to explicitly import `torch.fft`. however this will change +the reference `torch.fft` is pointing from function to module. +And this change takes effect not only in the client code but also in already-imported libraries too. +Similarly, if a library does the explicit import, the rest of the application code must use the +`torch.fft.fft` function. + +For this reason, to migrate the deprecated functions of fft-family, we need to use the new +implementation under `torch.fft` without explicitly importing `torch.fft` module. + +This module provides a simple interface for the migration, abstracting away +the access to the underlying C functions. + +Once the deprecated functions are removed from PyTorch and `torch.fft` starts to always represent +the new module, we can get rid of this module and call functions under `torch.fft` directly. +""" +from typing import Optional + +import torch + + +def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor: + # see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft + return torch._C._fft.fft_rfft(input, n, dim, norm) diff --git a/torchaudio/compliance/kaldi.py b/torchaudio/compliance/kaldi.py index 1160595992..afabebeb6a 100644 --- a/torchaudio/compliance/kaldi.py +++ b/torchaudio/compliance/kaldi.py @@ -2,9 +2,11 @@ import math import torch -import torchaudio from torch import Tensor +import torchaudio +import torchaudio._internal.fft + __all__ = [ 'get_mel_banks', 'inverse_mel_scale', @@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor, snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient) # size (m, padded_window_size // 2 + 1, 2) - fft = torch.rfft(strided_input, 1, normalized=False, onesided=True) + fft = torchaudio._internal.fft.rfft(strided_input) # Convert the FFT into a power spectrum - power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log() # size (m, padded_window_size // 2 + 1) + power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1) power_spectrum[:, 0] = signal_log_energy power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean) @@ -570,12 +572,10 @@ def fbank(waveform: Tensor, waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff, snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient) - # size (m, padded_window_size // 2 + 1, 2) - fft = torch.rfft(strided_input, 1, normalized=False, onesided=True) - - power_spectrum = fft.pow(2).sum(2) # size (m, padded_window_size // 2 + 1) - if not use_power: - power_spectrum = power_spectrum.pow(0.5) + # size (m, padded_window_size // 2 + 1) + spectrum = torchaudio._internal.fft.rfft(strided_input).abs() + if use_power: + spectrum = spectrum.pow(2.) # size (num_mel_bins, padded_window_size // 2) mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency, @@ -586,7 +586,7 @@ def fbank(waveform: Tensor, mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0) # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins) - mel_energies = torch.mm(power_spectrum, mel_energies.T) + mel_energies = torch.mm(spectrum, mel_energies.T) if use_log_fbank: # avoid log of zero (which should be prevented anyway by dithering) mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log() diff --git a/torchaudio/functional.py b/torchaudio/functional.py index d8adfe86e6..019f54059b 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -6,6 +6,7 @@ import torch from torch import Tensor +import torchaudio._internal.fft __all__ = [ "spectrogram", @@ -2073,7 +2074,7 @@ def _measure( dftBuf[measure_len_ws:dft_len_ws].zero_() # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf); - _dftBuf = torch.rfft(dftBuf, 1) + _dftBuf = torchaudio._internal.fft.rfft(dftBuf) # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf)); _dftBuf[:spectrum_start].zero_() @@ -2082,7 +2083,7 @@ def _measure( if boot_count >= 0 \ else measure_smooth_time_mult - _d = complex_norm(_dftBuf[spectrum_start:spectrum_end]) + _d = _dftBuf[spectrum_start:spectrum_end].abs() spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult)) _d = spectrum[spectrum_start:spectrum_end] ** 2 @@ -2106,12 +2107,9 @@ def _measure( _cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_() # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf); - _cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1) + _cepstrum_Buf = torchaudio._internal.fft.rfft(_cepstrum_Buf) - result: float = float(torch.sum( - complex_norm( - _cepstrum_Buf[cepstrum_start:cepstrum_end], - power=2.0))) + result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2))) result = \ math.log(result / (cepstrum_end - cepstrum_start)) \ if result > 0 \