Skip to content

Commit 39edd48

Browse files
committed
Update torch.rfft to torch.fft.rfft and complex tensor
1 parent 0f80bcf commit 39edd48

File tree

3 files changed

+42
-17
lines changed

3 files changed

+42
-17
lines changed

torchaudio/_internal/fft.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
"""Compatibility module for fft-related functions
2+
3+
In PyTorch 1.7, the new `torch.fft` module was introduced.
4+
5+
To use this new module, one has to explicitly import `torch.fft`. however this will change
6+
the reference `torch.fft` is pointing from function to module.
7+
And this change takes effect not only in the client code but also in already-imported libraries too.
8+
Similarly, if a library does the explicit import, the rest of the application code must use the
9+
`torch.fft.fft` function.
10+
11+
For this reason, to migrate the deprecated functions of fft-family, we need to use the new
12+
implementation under `torch.fft` without explicitly importing `torch.fft` module.
13+
14+
This module provides a simple interface for the migration, abstracting away
15+
the access to the underlying C functions.
16+
17+
Once the deprecated functions are removed from PyTorch and `torch.fft` starts to always represent
18+
the new module, we can get rid of this module and call functions under `torch.fft` directly.
19+
"""
20+
from typing import Optional
21+
22+
import torch
23+
24+
25+
def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor:
26+
# see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft
27+
return torch._C._fft.fft_rfft(input, n, dim, norm)

torchaudio/compliance/kaldi.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22

33
import math
44
import torch
5-
import torchaudio
65
from torch import Tensor
76

7+
import torchaudio
8+
import torchaudio._internal.fft
9+
810
__all__ = [
911
'get_mel_banks',
1012
'inverse_mel_scale',
@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor,
289291
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
290292

291293
# size (m, padded_window_size // 2 + 1, 2)
292-
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
294+
fft = torchaudio._internal.fft.rfft(strided_input)
293295

294296
# Convert the FFT into a power spectrum
295-
power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log() # size (m, padded_window_size // 2 + 1)
297+
power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log() # size (m, padded_window_size // 2 + 1)
296298
power_spectrum[:, 0] = signal_log_energy
297299

298300
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
@@ -570,12 +572,10 @@ def fbank(waveform: Tensor,
570572
waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
571573
snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)
572574

573-
# size (m, padded_window_size // 2 + 1, 2)
574-
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
575-
576-
power_spectrum = fft.pow(2).sum(2) # size (m, padded_window_size // 2 + 1)
577-
if not use_power:
578-
power_spectrum = power_spectrum.pow(0.5)
575+
# size (m, padded_window_size // 2 + 1)
576+
spectrum = torchaudio._internal.fft.rfft(strided_input).abs()
577+
if use_power:
578+
spectrum = spectrum.pow(2.)
579579

580580
# size (num_mel_bins, padded_window_size // 2)
581581
mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
@@ -586,7 +586,7 @@ def fbank(waveform: Tensor,
586586
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)
587587

588588
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
589-
mel_energies = torch.mm(power_spectrum, mel_energies.T)
589+
mel_energies = torch.mm(spectrum, mel_energies.T)
590590
if use_log_fbank:
591591
# avoid log of zero (which should be prevented anyway by dithering)
592592
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()

torchaudio/functional.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import torch
88
from torch import Tensor
9+
import torchaudio._internal.fft
910

1011
__all__ = [
1112
"spectrogram",
@@ -2073,7 +2074,7 @@ def _measure(
20732074
dftBuf[measure_len_ws:dft_len_ws].zero_()
20742075

20752076
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
2076-
_dftBuf = torch.rfft(dftBuf, 1)
2077+
_dftBuf = torchaudio._internal.fft.rfft(dftBuf)
20772078

20782079
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
20792080
_dftBuf[:spectrum_start].zero_()
@@ -2082,7 +2083,7 @@ def _measure(
20822083
if boot_count >= 0 \
20832084
else measure_smooth_time_mult
20842085

2085-
_d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
2086+
_d = _dftBuf[spectrum_start:spectrum_end].abs()
20862087
spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
20872088
_d = spectrum[spectrum_start:spectrum_end] ** 2
20882089

@@ -2106,12 +2107,9 @@ def _measure(
21062107
_cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()
21072108

21082109
# lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
2109-
_cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1)
2110+
_cepstrum_Buf = torchaudio._internal.fft.rfft(_cepstrum_Buf)
21102111

2111-
result: float = float(torch.sum(
2112-
complex_norm(
2113-
_cepstrum_Buf[cepstrum_start:cepstrum_end],
2114-
power=2.0)))
2112+
result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
21152113
result = \
21162114
math.log(result / (cepstrum_end - cepstrum_start)) \
21172115
if result > 0 \

0 commit comments

Comments
 (0)