Update torch.rfft to torch.fft.rfft and complex tensor #941
New file (imported below as `torchaudio._internal.fft`):

```python
"""Compatibility module for fft-related functions

In PyTorch 1.7, the new `torch.fft` module was introduced.
To use this new module, one has to explicitly import `torch.fft`. However, this changes
what the name `torch.fft` refers to, from a function to a module,
and the change takes effect not only in the client code but also in already-imported libraries.
Similarly, if a library does the explicit import, the rest of the application code must use the
`torch.fft.fft` function.
For this reason, to migrate the deprecated functions of the fft family, we need to use the new
implementation under `torch.fft` without explicitly importing the `torch.fft` module.
This module provides a simple interface for the migration, abstracting away
the access to the underlying C functions.
Once the deprecated functions are removed from PyTorch and `torch.fft` always refers to
the new module, we can get rid of this module and call the functions under `torch.fft` directly.
"""
from typing import Optional

import torch


def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor:
    # see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft
    return torch._C._fft.fft_rfft(input, n, dim, norm)
```
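The docstring's point about `import torch.fft` rebinding a name process-wide can be illustrated with a stdlib-only sketch. `fakepkg` below is a hypothetical stand-in for torch, not PyTorch itself; the `setattr` step simulates what the import system does when a submodule is imported:

```python
import sys
import types

# Hypothetical package standing in for torch: the attribute `fft`
# starts out as a function (like the deprecated torch.fft function).
pkg = types.ModuleType("fakepkg")
pkg.fft = lambda x: "deprecated-function-result"
sys.modules["fakepkg"] = pkg

import fakepkg
assert callable(fakepkg.fft)  # still the deprecated function

# What `import fakepkg.fft` does under the hood: the import system
# executes the submodule, registers it in sys.modules, and binds it
# onto the parent package, replacing the function. The rebinding is
# process-wide, so already-imported modules see the module too.
sub = types.ModuleType("fakepkg.fft")
sub.rfft = lambda x: "new-module-result"
sys.modules["fakepkg.fft"] = sub
setattr(pkg, "fft", sub)  # the side effect the docstring warns about

assert isinstance(fakepkg.fft, types.ModuleType)
```

This is why the compatibility module reaches for the underlying implementation instead of importing `torch.fft`: the import itself would change what `torch.fft` means for every other module in the process.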
Changes to `kaldi.py`:

```diff
@@ -2,9 +2,11 @@
 import math
 import torch
-import torchaudio
 from torch import Tensor
+
+import torchaudio
+import torchaudio._internal.fft

 __all__ = [
     'get_mel_banks',
     'inverse_mel_scale',
```

```diff
@@ -289,10 +291,10 @@ def spectrogram(waveform: Tensor,
         snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)

     # size (m, padded_window_size // 2 + 1, 2)
-    fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
+    fft = torchaudio._internal.fft.rfft(strided_input)

     # Convert the FFT into a power spectrum
-    power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log()  # size (m, padded_window_size // 2 + 1)
+    power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log()  # size (m, padded_window_size // 2 + 1)
     power_spectrum[:, 0] = signal_log_energy

     power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
```

```diff
@@ -570,12 +572,10 @@ def fbank(waveform: Tensor,
         waveform, padded_window_size, window_size, window_shift, window_type, blackman_coeff,
         snip_edges, raw_energy, energy_floor, dither, remove_dc_offset, preemphasis_coefficient)

-    # size (m, padded_window_size // 2 + 1, 2)
-    fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
-
-    power_spectrum = fft.pow(2).sum(2)  # size (m, padded_window_size // 2 + 1)
-    if not use_power:
-        power_spectrum = power_spectrum.pow(0.5)
+    # size (m, padded_window_size // 2 + 1)
+    spectrum = torchaudio._internal.fft.rfft(strided_input).abs()
+    if use_power:
+        spectrum = spectrum.pow(2.)

     # size (num_mel_bins, padded_window_size // 2)
     mel_energies, _ = get_mel_banks(num_mel_bins, padded_window_size, sample_frequency,
```

```diff
@@ -586,7 +586,7 @@ def fbank(waveform: Tensor,
     mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode='constant', value=0)

     # sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
-    mel_energies = torch.mm(power_spectrum, mel_energies.T)
+    mel_energies = torch.mm(spectrum, mel_energies.T)
     if use_log_fbank:
         # avoid log of zero (which should be prevented anyway by dithering)
         mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
```
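The `fft.pow(2).sum(2)` to `fft.abs().pow(2.)` rewrite relies on a simple identity: the old `torch.rfft` returned a (real, imag) pair in the last dimension, and summing the squares over that dimension equals the squared magnitude of the corresponding complex value. A stdlib-only sketch with a naive one-sided DFT (illustrative only, not torchaudio code):

```python
import cmath

def naive_rfft(x):
    """One-sided DFT of a real signal; returns n // 2 + 1 complex bins."""
    n = len(x)
    return [sum(x[t] * cmath.exp(-2j * cmath.pi * k * t / n) for t in range(n))
            for k in range(n // 2 + 1)]

signal = [0.0, 1.0, 0.0, -1.0, 0.5, -0.5, 0.25, -0.25]
spec = naive_rfft(signal)

for z in spec:
    old_style = z.real ** 2 + z.imag ** 2   # like fft.pow(2).sum(2) on (real, imag) pairs
    new_style = abs(z) ** 2                 # like fft.abs().pow(2.) on complex values
    assert abs(old_style - new_style) < 1e-9
```

The same identity explains the `fbank` change: computing `.abs()` first and squaring only when `use_power` is set avoids the old pattern of squaring, summing, and then taking a square root back out for the non-power case.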
Changes to `functional.py`:

```diff
@@ -6,6 +6,7 @@
 import torch
 from torch import Tensor
+import torchaudio._internal.fft

 __all__ = [
     "spectrogram",
```

```diff
@@ -2073,7 +2074,7 @@ def _measure(
     dftBuf[measure_len_ws:dft_len_ws].zero_()

     # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
-    _dftBuf = torch.rfft(dftBuf, 1)
+    _dftBuf = torchaudio._internal.fft.rfft(dftBuf)

     # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
     _dftBuf[:spectrum_start].zero_()
```
```diff
@@ -2082,7 +2083,7 @@ def _measure(
         if boot_count >= 0 \
         else measure_smooth_time_mult

-    _d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
+    _d = _dftBuf[spectrum_start:spectrum_end].abs()
     spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
     _d = spectrum[spectrum_start:spectrum_end] ** 2
```

Review thread on this change:

I'd keep the change norm/abs separate. In a prior attempt, there was a performance regression, #747.

@vincentqb Can you elaborate why the regression in […]

I'll let @anjali411 @mruberry comment on how […]

Datapoint: `torch.abs` and `torch.norm` have separate implementations.

I think it should be perfectly fine to use […]

Good to know :)

We had […] In any case, let's just separate concerns and move this to a separate pull request so we don't block the rest on this :)

There's a significant performance gain (as expected) so I think we should switch to […]

Thanks @anjali411! In addition to that, I did a benchmark for the exact code path that @vincentqb suggested and observed that […]

PyTorch: 1.8.0a0+edac406

```shell
OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
t.abs();
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
t.abs();
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""
```

Thanks for checking!
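The command-line benchmarks above rely on `timeit`'s split between a setup block (`-s "..."`) and the timed statement. The same structure is available from the stdlib `timeit` module; the sketch below times pure-Python stand-ins (not the torch kernels actually being measured) just to show the methodology:

```python
import timeit

# `setup` plays the role of the `-s "..."` block in the shell commands;
# only the statement string is timed, the setup runs once per repeat.
setup = "data = [complex(i, -i) for i in range(10_000)]"

t_abs = timeit.timeit("[abs(z) for z in data]", setup=setup, number=50)
t_norm = timeit.timeit("[(z.real ** 2 + z.imag ** 2) ** 0.5 for z in data]",
                       setup=setup, number=50)
```

Keeping tensor construction in the setup string, as the commands above do, ensures the timing reflects only the magnitude computation and not allocation or RNG cost.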
```diff
@@ -2106,12 +2107,9 @@ def _measure(
     _cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()

     # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
-    _cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1)
+    _cepstrum_Buf = torchaudio._internal.fft.rfft(_cepstrum_Buf)

-    result: float = float(torch.sum(
-        complex_norm(
-            _cepstrum_Buf[cepstrum_start:cepstrum_end],
-            power=2.0)))
+    result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))

     result = \
         math.log(result / (cepstrum_end - cepstrum_start)) \
         if result > 0 \
```
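The rewritten `result` computation reduces to "sum of squared magnitudes over a slice, then a guarded log-average". A minimal numeric sketch; the buffer values, slice bounds, and the fallback value in the `else` branch are all made up for illustration (the source is cut off before the real fallback):

```python
import math

# Stand-in complex buffer and slice bounds (hypothetical values).
buf = [complex(0.3, -0.4), complex(1.0, 0.0), complex(0.0, 2.0), complex(0.5, 0.5)]
start, end = 1, 3

# Equivalent of float(torch.sum(buf[start:end].abs().pow(2)))
result = sum(abs(z) ** 2 for z in buf[start:end])

# Guarded log-average, mirroring the conditional in the diff;
# the else value here is an assumption, not taken from the source.
result = math.log(result / (end - start)) if result > 0 else float("-inf")
```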
I wrote a response in the document you've shared about why this mapping is more complicated. I can probably write a mapping from `torch.rfft` to `torch.fft.rfft` for you, if you like, but I'm not sure that's actually what you want to do. `torch.fft.rfft` returns a complex tensor, for example, so switching to it doesn't make a lot of sense unless you're also using complex tensors.

For now you may just want to suppress the warning that `torch.rfft` throws?

Or you could use `torch.view_as_real` at the end, although I agree with @mruberry that it would be kind of wasteful to use `torch.fft.rfft` before we migrate audio to start using complex tensors.
@mruberry, @anjali411

I looked at your response and it was very helpful. However, that document was for general cases, while here we are dealing with the migration of specific use cases, so I think the concerns brought up are not relevant. Here are the reasons:

1. All the `torch.rfft` usages in `torchaudio` are `onesided=True`. This means we can simply migrate to `torch.fft.rfft`, and it is sufficient; we do not need to consider `torch.fft.rfftn`.
2. We can pass `norm=None` to `torch.fft.rfft`, which is the default, so in my code change it is omitted.
3. Look at the changes in `functional.py` and `kaldi.py`: `rfft` is used as a means of computing the power of the input signal in the frequency domain. Therefore, with the appropriate changes (like using `.abs` on the complex dtype), no computation is wasted and no unnecessary computation is introduced.

For torchaudio's use case, the general mapping from `torch.rfft` to `torch.fft.rfft` is not necessary.
Note that this compatibility module is not for migration from `torch.rfft` to `torch.fft.rfft` but for TorchScript compatibility of `torch.fft.rfft` (and others coming up) while avoiding an explicit import of `torch.fft`.
If there is a function that computes the power spectrum (which is real-valued) directly, then we do not need to use `torch.fft.rfft`; but otherwise, the use of the complex dtype here as an intermediate representation makes sense and is not wasteful.
Interesting. Thanks for the additional context, @mthrok. Would it make sense, then, to limit the signature of this function, or add a comment explaining that it only does a limited translation from `torch.rfft` to `torch.fft.rfft`?

Also, I'm not sure how many versions of PyTorch torchaudio supports, but `torch.fft.rfft` will only be available in PyTorch 1.7+, so previous versions of PyTorch will still need to call `torch.rfft`.
@mruberry

I gave it some thought, and we could do that, but I think having the same signature as `torch.fft.rfft` has advantages when it comes to maintainability:

1. When someone adds a new feature that requires `rfft`, they can simply use the same signature as `torch.fft.rfft`, which frees them from having to consider the relationship between our abstraction function and the other torchaudio functions that use it. (In short, the person working on the new feature can use this abstraction as a drop-in replacement.)
2. Once the deprecated functions are gone and this module is removed, the migration is a mechanical rename (`torchaudio._internal.fft.rfft` -> `torch.fft.rfft`).

That is okay for domain libraries. We clearly state that domain libraries are expected/tested to work with the version of PyTorch that is released at the same time. So all the work on the master branch expects the master version (or the next latest stable release) of PyTorch.
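The drop-in-replacement argument can be sketched with a hypothetical wrapper: because the wrapper mirrors the target function's signature exactly, removing the indirection later is a pure rename at every call site. The names below are illustrative, not torchaudio's actual code:

```python
from typing import Optional

# Hypothetical backend standing in for the underlying C function.
def _backend_rfft(input, n=None, dim=-1, norm=None):
    return ("rfft", tuple(input), n, dim, norm)

# The wrapper keeps the target's exact signature (input, n, dim, norm),
# so callers written against it need no changes when the wrapper is
# later replaced by the real function.
def rfft(input, n: Optional[int] = None, dim: int = -1,
         norm: Optional[str] = None):
    return _backend_rfft(input, n, dim, norm)

assert rfft([1, 2, 3], norm="ortho") == _backend_rfft([1, 2, 3], norm="ortho")
```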