8 changes: 8 additions & 0 deletions setup.py
@@ -92,6 +92,14 @@ def check_env_flag(name, default=''):
        extra_compile_args=eca,
        extra_objects=extra_objects,
        extra_link_args=ela),
+    CppExtension(
+        '_torch_overdrive',
+        ['torchaudio/torch_overdrive.cpp'],
+        libraries=libraries,
+        include_dirs=include_dirs,
+        extra_compile_args=eca,
+        extra_objects=extra_objects,
+        extra_link_args=ela),
]

setup(
6 changes: 2 additions & 4 deletions torchaudio/functional.py
@@ -5,6 +5,7 @@

import torch
from torch import Tensor
+from _torch_overdrive import _overdrive_helper

__all__ = [
    "istft",

@@ -1287,10 +1288,7 @@ def overdrive(
    output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)

    # TODO: Implement a torch CPP extension
-    for i in range(waveform.shape[-1]):
-        last_out = temp[:, i] - last_in + 0.995 * last_out
-        last_in = temp[:, i]
-        output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75
+    _overdrive_helper(waveform, temp, last_in, last_out, output_waveform)

    return output_waveform.clamp(min=-1, max=1).view(actual_shape)

52 changes: 52 additions & 0 deletions torchaudio/torch_overdrive.cpp
@@ -0,0 +1,52 @@
#include <torch/extension.h>

namespace torch {
namespace audio {

template <typename scalar_t>
void overdrive_cpu_kernel(
    at::TensorAccessor<scalar_t, 2> waveform_accessor,
    at::TensorAccessor<scalar_t, 2> temp_accessor,
    at::TensorAccessor<scalar_t, 1> last_in_accessor,
    at::TensorAccessor<scalar_t, 1> last_out_accessor,
    at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
  int64_t n_frames = waveform_accessor.size(1);
  int64_t n_channels = waveform_accessor.size(0);

  for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) {
Contributor (cpuhrsch):
Depending on the amount of work, you might benefit from using parallel_for.

Most PyTorch CPU operators are parallelized, unless there's no obvious need due to memory-boundedness.

Another issue with plain C-style loops in C++ extensions is autovectorization: we can't ship AVX2 code without CPU-capability-based dispatch, so hand-written loops in extensions like this are for now restricted to SSE and related instruction sets.

Of course this is taken care of when you call into at:: operations directly, since they each take advantage of being part of libtorch.

Contributor Author (bhargavkathivarapu):
@cpuhrsch, yeah, parallelization can be applied only to the channels loop. I wasn't sure how parallel_for treats the inner sequential loop, so I kept it without parallel_for. One parallel thread won't interfere with another parallel thread's inner loop, right?

Contributor (cpuhrsch), May 12, 2020:
@bhargavkathivarapu - I'm not sure what you mean by "interfere with" exactly. As in, shared variables, creating integers, etc.? In this particular case the inner loops are independent of each other, given that they differ in i_channel. The pointers and such will still be picked up as shared variables, but as long as you don't write to a single memory location from multiple threads concurrently, there's no issue.

By default PyTorch's parallel_for is backed by OpenMP. Look into OpenMP's omp parallel for more detail on what that means.
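For concreteness, a minimal sketch of what the channel loop could look like under at::parallel_for, assuming channel independence as discussed (the grain size of 1 is an arbitrary placeholder; this is not part of the PR):

#include <ATen/Parallel.h>

// Hypothetical sketch: chunks of channels run on separate threads; the frame
// loop stays sequential because last_out carries recursive state. Each thread
// writes only its own i_channel entries, so no memory location is written
// from more than one thread.
at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
  for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
    for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
      last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
          last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
      last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
      output_waveform_accessor[i_channel][i_frame] =
          waveform_accessor[i_channel][i_frame] * 0.5 +
          last_out_accessor[i_channel] * 0.75;
    }
  }
});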

    for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
      last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
          last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
      last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
Contributor:
You're setting last_in to the current iteration's temp value so that it can be read in the next iteration. But instead you could just read from temp all the time (except for the first iteration), right? I added a similar comment for the Python code above.

Contributor Author (bhargavkathivarapu):
> But instead you could just read from temp all the time (except for the first iteration), right?

Yes, the first iteration needs to be handled; then we can remove the last_in variable.
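To make that concrete, a hedged sketch of the inner loop with the per-frame last_in write removed, reading the previous temp sample directly (this assumes last_in only needs to seed the first frame; not part of this PR):

for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
  // Previous input sample: the incoming last_in state for frame 0,
  // otherwise the prior column of temp.
  scalar_t prev_in = (i_frame == 0) ? last_in_accessor[i_channel]
                                    : temp_accessor[i_channel][i_frame - 1];
  last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
      prev_in + 0.995 * last_out_accessor[i_channel];
  output_waveform_accessor[i_channel][i_frame] =
      waveform_accessor[i_channel][i_frame] * 0.5 +
      last_out_accessor[i_channel] * 0.75;
}
// If callers rely on last_in afterwards, it could be updated once after the
// loop (guarding against n_frames == 0):
// last_in_accessor[i_channel] = temp_accessor[i_channel][n_frames - 1];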

      output_waveform_accessor[i_channel][i_frame] =
          waveform_accessor[i_channel][i_frame] * 0.5 +
          last_out_accessor[i_channel] * 0.75;
    }
  }
}

void _overdrive_helper_cpu(
    at::Tensor& waveform,
    at::Tensor& temp,
    at::Tensor& last_in,
    at::Tensor& last_out,
    at::Tensor& output_waveform) {
  AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
    overdrive_cpu_kernel<scalar_t>(
        waveform.accessor<scalar_t, 2>(),
        temp.accessor<scalar_t, 2>(),
        last_in.accessor<scalar_t, 1>(),
        last_out.accessor<scalar_t, 1>(),
        output_waveform.accessor<scalar_t, 2>());
  }));
}

} // namespace audio
} // namespace torch

PYBIND11_MODULE(_torch_overdrive, m) {
  m.def(
      "_overdrive_helper",
      &torch::audio::_overdrive_helper_cpu,
      "Executes helper loop for overdrive effect");
}