diff --git a/setup.py b/setup.py index 0128a2fe61..c743a8e004 100644 --- a/setup.py +++ b/setup.py @@ -92,6 +92,14 @@ def check_env_flag(name, default=''): extra_compile_args=eca, extra_objects=extra_objects, extra_link_args=ela), + CppExtension( + '_torch_overdrive', + ['torchaudio/torch_overdrive.cpp'], + libraries=libraries, + include_dirs=include_dirs, + extra_compile_args=eca, + extra_objects=extra_objects, + extra_link_args=ela), ] setup( diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 827b9990c1..95b2b98db8 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -5,6 +5,7 @@ import torch from torch import Tensor +from _torch_overdrive import _overdrive_helper __all__ = [ "istft", @@ -1287,10 +1288,7 @@ def overdrive( output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device) # TODO: Implement a torch CPP extension - for i in range(waveform.shape[-1]): - last_out = temp[:, i] - last_in + 0.995 * last_out - last_in = temp[:, i] - output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75 + _overdrive_helper(waveform, temp, last_in, last_out, output_waveform) return output_waveform.clamp(min=-1, max=1).view(actual_shape) diff --git a/torchaudio/torch_overdrive.cpp b/torchaudio/torch_overdrive.cpp new file mode 100644 index 0000000000..d3388977af --- /dev/null +++ b/torchaudio/torch_overdrive.cpp @@ -0,0 +1,52 @@ +#include + +namespace torch { +namespace audio { + +template +void overdrive_cpu_kernel( + at::TensorAccessor waveform_accessor, + at::TensorAccessor temp_accessor, + at::TensorAccessor last_in_accessor, + at::TensorAccessor last_out_accessor, + at::TensorAccessor output_waveform_accessor) { + int64_t n_frames = waveform_accessor.size(1); + int64_t n_channels = waveform_accessor.size(0); + + for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) { + for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { + last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - + last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; + last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; + output_waveform_accessor[i_channel][i_frame] = + waveform_accessor[i_channel][i_frame] * 0.5 + + last_out_accessor[i_channel] * 0.75; + } + } +} + +void _overdrive_helper_cpu( + at::Tensor& waveform, + at::Tensor& temp, + at::Tensor& last_in, + at::Tensor& last_out, + at::Tensor& output_waveform) { + AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] { + overdrive_cpu_kernel( + waveform.accessor(), + temp.accessor(), + last_in.accessor(), + last_out.accessor(), + output_waveform.accessor()); + })); +} + +} // namespace audio +} // namespace torch + +PYBIND11_MODULE(_torch_overdrive, m) { + m.def( + "_overdrive_helper", + &torch::audio::_overdrive_helper_cpu, + "Executes helper loop for overdrive effect"); +}