
Conversation

Contributor

@mthrok mthrok commented Oct 8, 2020

Update torch.rfft (deprecated) to use the equivalent of torch.fft.rfft, but without importing torch.fft.
Also, update the power computation of complex values to use complex-dtype tensors.

@mthrok mthrok requested review from anjali411 and mruberry October 8, 2020 23:13

def rfft(input: torch.Tensor, n: Optional[int] = None, dim: int = -1, norm: Optional[str] = None) -> torch.Tensor:
    # see: https://pytorch.org/docs/master/fft.html#torch.fft.rfft
    return torch._C._fft.fft_rfft(input, n, dim, norm)
Contributor

@mruberry mruberry Oct 9, 2020

I wrote a response in the document you've shared about why this mapping is more complicated. I can probably write a mapping from torch.rfft to torch.fft.rfft for you, if you like, but I'm not sure that's actually what you want to do. torch.fft.rfft returns a complex tensor, for example, so switching to it doesn't make a lot of sense unless you're also using complex tensors.

For now you may just want to suppress the warning that torch.rfft throws?
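For illustration only (not part of this PR): on versions where torch.rfft still exists, the deprecation warning can be silenced locally, roughly like this:

```python
import warnings

import torch

x = torch.randn(2, 1024)

with warnings.catch_warnings():
    # torch.rfft emits a UserWarning about its deprecation; ignore it just here.
    warnings.simplefilter("ignore")
    spec = torch.rfft(x, 1, normalized=False, onesided=True)  # shape (2, 513, 2)
```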

@anjali411

or you could use torch.view_as_real at the end, although I agree with @mruberry that it would be kind of wasteful to use torch.fft.rfft before we migrate audio to start using complex tensors.
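As a rough sketch of that suggestion (illustrative, not the PR's code): torch.view_as_real turns the complex output of torch.fft.rfft back into the trailing (real, imag) layout that torch.rfft returns.

```python
import torch
import torch.fft

x = torch.randn(2, 1024)

spec = torch.fft.rfft(x)              # complex dtype, shape (2, 513)
spec_real = torch.view_as_real(spec)  # real dtype, shape (2, 513, 2), like torch.rfft(x, 1, onesided=True)
```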

Contributor Author

@mthrok mthrok Oct 9, 2020

@mruberry, @anjali411

I wrote a response in the document you've shared about why this mapping is more complicated.

I looked at your response and it was very helpful. However, that document covers general cases, while here we are dealing with the migration of specific use cases, so I think the concerns brought up do not apply. Here are the reasons:

All the torch.rfft usages in torchaudio have the following properties:

  1. Inputs are always 2D and onesided=True.
    This means that we can simply migrate to torch.fft.rfft, and it is sufficient. We do not need to consider torch.fft.rfftn.
  2. No normalization is required.
    We can pass norm=None to torch.fft.rfft, which is the default, so it is omitted in my code change.
  3. The complex values are immediately used to compute power.
    Look at the changes in functional.py and kaldi.py. rfft is used as a means of computing the power of the input signal in the frequency domain. Therefore, with the appropriate changes (such as using .abs() on the complex dtype), no computation is wasted and no unnecessary computation is introduced (see the sketch at the end of this comment).

I can probably write a mapping from torch.rfft to torch.fft.rfft

For torchaudio's use case, the general mapping from torch.rfft to torch.fft.rfft is not necessary.
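A minimal sketch of the pattern described in point 3 (shapes and variable names are illustrative, not the exact diff): the power computed from the old real layout matches abs() squared on the complex output.

```python
import torch
import torch.fft

frames = torch.randn(4, 400)  # e.g. framed audio, (num_frames, window_size)

# Old-style layout: trailing (real, imag) dimension, power via pow/sum.
old_layout = torch.view_as_real(torch.fft.rfft(frames))  # stands in for torch.rfft(frames, 1, onesided=True)
power_old = old_layout.pow(2).sum(-1)                    # shape (4, 201)

# New-style: stay in the complex dtype, power via abs().
power_new = torch.fft.rfft(frames).abs().pow(2.)         # shape (4, 201)

assert torch.allclose(power_old, power_new)
```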

Contributor Author

Note that this compatibility module is not for migration from torch.rfft to torch.fft.rfft, but for TorchScript compatibility of torch.fft.rfft (and others coming up) while avoiding an explicit import of torch.fft.

Contributor Author

If there were a function that computes the power spectrum (a real value) directly, then we would not need to use torch.fft.rfft. Otherwise, the use of the complex dtype here as an intermediate representation makes sense and is not wasteful.

Contributor

Interesting. Thanks for the additional context, @mthrok. Would it make sense, then, to limit the signature of this function or add a comment explaining that it only does a limited translation from torch.rfft to torch.fft.rfft?

Also, I'm not sure how many versions of PyTorch torchaudio supports, but torch.fft.rfft will only be available in PyTorch 1.7+, so previous versions of PyTorch will still need to call torch.rfft.

Contributor Author

@mruberry

Would it make sense, then, to limit the signature of this function or add a comment explaining that it only does a limited translation from torch.rfft to torch.fft.rfft?

I gave it some thought, and we could do that, but I think having the same signature as torch.fft.rfft has advantages when it comes to maintainability.

  1. If someone wants to add new functionality that uses rfft, they can simply use the same signature as torch.fft.rfft, which frees them from having to consider the relationship between our abstraction function and the other torchaudio functions that use it. (In short, the person working on the new feature can use this abstraction as a drop-in replacement.)
  2. Similar to 1., in the future, once PyTorch is done with the migration, we can simply replace the abstraction function path with the actual PyTorch implementation (torchaudio._internal.fft.rfft -> torch.fft.rfft).

Also, I'm not sure how many versions of PyTorch torchaudio supports, but torch.fft.rfft will only be available in PyTorch 1.7+, so previous versions of PyTorch will still need to call torch.rfft.

That is okay for domain libraries. We clearly state that domain libraries are expected/tested to work with the version of PyTorch that is released at the same time, so all the work on the master branch targets the master version (or the next stable release) of PyTorch.
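To make point 1 above concrete, a call-site sketch (the spectrogram_power helper is hypothetical; the wrapper path is the one added in this PR):

```python
import torch
from torchaudio._internal import fft as _fft  # thin wrapper with torch.fft.rfft's signature

def spectrogram_power(waveform: torch.Tensor) -> torch.Tensor:
    # Because the signature matches torch.fft.rfft, this call can later become
    # torch.fft.rfft(waveform) with no other change to this function.
    return _fft.rfft(waveform).abs().pow(2.)
```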

if boot_count >= 0 \
else measure_smooth_time_mult

_d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
Contributor

I'd keep the norm/abs change separate. In a prior attempt, there was a performance regression, #747.

Contributor Author

@vincentqb Can you elaborate on why the regression in torch.norm, which is not used in this PR, is a reason to discourage the use of torch.abs? Are they using the same implementation under the hood?

Contributor

@vincentqb vincentqb Oct 16, 2020

I'll let @anjali411 and @mruberry comment on whether torch.abs and torch.norm are related. If they are, we would end up with a performance regression again. We could add some performance tests/checks, manually or automatically, to catch performance changes. However, my suggestion is simply to decouple the changes about torch.*fft* from those about the use of .abs() so that, in the event we do get a regression, we can easily track it and revert it.

Contributor

Datapoint: torch.abs and torch.norm have separate implementations.

@anjali411 anjali411 Oct 16, 2020

I think it should be perfectly fine to use torch.abs() here. complex_norm is not even using torch.norm https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L424. In fact, it's using three kernels (pow, sum, pow), so torch.abs() should be a strict improvement here.
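For context, a paraphrase of the implementation at the linked line (approximate, reconstructed from the description above) next to the single-kernel complex path:

```python
import torch

def complex_norm(complex_tensor: torch.Tensor, power: float = 1.0) -> torch.Tensor:
    # Real layout (..., 2): three kernels -- pow, sum, pow.
    return complex_tensor.pow(2.).sum(-1).pow(0.5 * power)

x = torch.randn(10, 20, 2)
z = torch.view_as_complex(x)

assert torch.allclose(complex_norm(x), z.abs())  # abs() on the complex dtype is a single kernel
```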

Contributor

Datapoint: torch.abs and torch.norm have separate implementations.

Good to know :)

I think it should be perfectly fine to use torch.abs() here. complex_norm is not even using torch.norm https://github.com/pytorch/audio/blob/master/torchaudio/functional.py#L424. In fact, it's using three kernels (pow, sum, pow), so torch.abs() should be a strict improvement here.

We had complex_norm using torch.norm, but this led to a speed regression, so we reverted to the current implementation with three kernels. Because of this, I want to avoid a performance regression again. I'm glad to know this would launch only one kernel: has the performance difference been tested and compared in this case?

In any case, let's just separate concerns and move this to a separate pull request so we don't block the rest on this :)

@anjali411

x=torch.randn(10, 20, 2)
In [7]: def fn(x):
   ...:     t0=time.time()
   ...:     o=x.pow(2.0).sum(-1).pow(0.5)
   ...:     t1 = time.time()
   ...:     print(t1-t0)
   ...:

In [8]: fn(x)
0.00037217140197753906

In [10]: y=torch.view_as_complex(x).contiguous()

In [11]: fn(y)
0.0001480579376220703

There's a significant performance gain (as expected) so I think we should switch to torch.abs.

Contributor Author

@mthrok mthrok Oct 30, 2020

Thanks @anjali411 !

In addition to that, I benchmarked the exact code path that @vincentqb suggested and observed that abs is 10x faster on CPU and 2x faster on GPU.

| Method | CPU | GPU |
| --- | --- | --- |
| `t.abs()` | 100 loops, best of 5: 1.34 msec per loop | 100 loops, best of 5: 28.2 usec per loop |
| `complex_norm(view_as_real(t))` | 100 loops, best of 5: 14.9 msec per loop | 100 loops, best of 5: 60.5 usec per loop |

PyTorch: 1.8.0a0+edac406
torchaudio: 0.8.0a0+0076ab0

Code:
OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
t.abs();
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100));
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
t.abs();
"""

OMP_NUM_THREADS=1 numactl --membind 0 --cpubind 0 python -m timeit -n 100 -r 5 -s """
import torch;
import torch.fft;
import torchaudio.functional;
torch.manual_seed(0);
t = torch.fft.rfft(torch.randn(1, 32*44100)).to('cuda');
""" """
torchaudio.functional.complex_norm(torch.view_as_real(t));
"""

Contributor

Thanks for checking!

- fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
+ fft = torchaudio._internal.fft.rfft(strided_input)

power_spectrum = fft.pow(2).sum(2).unsqueeze(1) # size (m, 1, padded_window_size // 2 + 1)
Contributor

same

- complex_norm(
-     _cepstrum_Buf[cepstrum_start:cepstrum_end],
-     power=2.0)))
+ result: float = float(torch.sum(_cepstrum_Buf[cepstrum_start:cepstrum_end].abs().pow(2)))
Contributor

same


  # Convert the FFT into a power spectrum
- power_spectrum = torch.max(fft.pow(2).sum(2), epsilon).log()  # size (m, padded_window_size // 2 + 1)
+ power_spectrum = torch.max(fft.abs().pow(2.), epsilon).log()  # size (m, padded_window_size // 2 + 1)
Contributor

same

@mthrok mthrok force-pushed the rfft-migrate branch 2 times, most recently from 39edd48 to ac02f9f on October 27, 2020 17:58
@anjali411 anjali411 left a comment

complex number stuff looks good

Contributor

@vincentqb vincentqb left a comment

LGTM. Any reason why this is marked as draft?

@mthrok mthrok marked this pull request as ready for review November 5, 2020 17:15
@mthrok mthrok merged commit 48d2b57 into pytorch:master Nov 5, 2020
@mthrok mthrok deleted the rfft-migrate branch November 5, 2020 17:16