From 8ea3f4d835982315836120a0b682f416eea4ca1c Mon Sep 17 00:00:00 2001 From: Bhargav Kathivarapu Date: Fri, 24 Apr 2020 12:09:42 +0530 Subject: [PATCH 1/3] cpp overdrive Signed-off-by: Bhargav Kathivarapu --- setup.py | 8 +++++ torchaudio/functional.py | 6 ++-- torchaudio/torch_overdrive.cpp | 56 ++++++++++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 4 deletions(-) create mode 100644 torchaudio/torch_overdrive.cpp diff --git a/setup.py b/setup.py index 0128a2fe61..c743a8e004 100644 --- a/setup.py +++ b/setup.py @@ -92,6 +92,14 @@ def check_env_flag(name, default=''): extra_compile_args=eca, extra_objects=extra_objects, extra_link_args=ela), + CppExtension( + '_torch_overdrive', + ['torchaudio/torch_overdrive.cpp'], + libraries=libraries, + include_dirs=include_dirs, + extra_compile_args=eca, + extra_objects=extra_objects, + extra_link_args=ela), ] setup( diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 827b9990c1..45f5e67899 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -5,6 +5,7 @@ import torch from torch import Tensor +from _torch_overdrive import _overdrive_float __all__ = [ "istft", @@ -1287,10 +1288,7 @@ def overdrive( output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device) # TODO: Implement a torch CPP extension - for i in range(waveform.shape[-1]): - last_out = temp[:, i] - last_in + 0.995 * last_out - last_in = temp[:, i] - output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75 + _overdrive_float(waveform,temp,last_in,last_out,output_waveform) return output_waveform.clamp(min=-1, max=1).view(actual_shape) diff --git a/torchaudio/torch_overdrive.cpp b/torchaudio/torch_overdrive.cpp new file mode 100644 index 0000000000..8f2eb80bc3 --- /dev/null +++ b/torchaudio/torch_overdrive.cpp @@ -0,0 +1,56 @@ +#include +// TBD - for CUDA support #include +// TBD - Compile on CUDA + +namespace torch { +namespace audio { + + + template + void _overdrive_float( + at::Tensor & waveform, + at::Tensor & temp, + at::Tensor & last_in, + at::Tensor & last_out, + at::Tensor & output_waveform + ) { + int64_t n_frames = waveform.size(1); + int64_t n_channels = waveform.size(0); + // Create CPU accessors for fast access + // https://pytorch.org/cppdocs/notes/tensor_basics.html + auto waveform_accessor = waveform.accessor(); + auto temp_accessor = temp.accessor(); + auto last_in_accessor = last_in.accessor(); + auto last_out_accessor = last_out.accessor(); + auto output_waveform_accessor = output_waveform.accessor(); + + for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) { + + for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { + last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; + last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; + output_waveform_accessor[i_channel][i_frame]= waveform_accessor[i_channel][i_frame] * 0.5 + last_out_accessor[i_channel] * 0.75; + } + } + /* + for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { + + last_out = temp.slice(1,i_frame,i_frame+1) - last_in + 0.995 * last_out; + last_in = temp.slice(1,i_frame,i_frame+1); + output_waveform.slice(1,i_frame,i_frame+1) = waveform.slice(1,i_frame,i_frame+1) * 0.5 + last_out * 0.75; + + } + */ + + } + +} // namespace audio +} // namespace torch + + +PYBIND11_MODULE(_torch_overdrive, m) { + m.def( + "_overdrive_float", + &torch::audio::_overdrive_float, + "Executes difference equation with tensor"); +} From fddbdeda99838f2bc65163bde3b0cfac5ca9c575 Mon Sep 17 00:00:00 2001 From: Bhargav Kathivarapu Date: Fri, 24 Apr 2020 17:11:16 +0530 Subject: [PATCH 2/3] dynamic dispatch Signed-off-by: Bhargav Kathivarapu --- torchaudio/functional.py | 4 +-- torchaudio/torch_overdrive.cpp | 60 ++++++++++++++++++---------------- 2 files changed, 34 insertions(+), 30 deletions(-) diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 45f5e67899..95b2b98db8 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -5,7 +5,7 @@ import torch from torch import Tensor -from _torch_overdrive import _overdrive_float +from _torch_overdrive import _overdrive_helper __all__ = [ "istft", @@ -1288,7 +1288,7 @@ def overdrive( output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device) # TODO: Implement a torch CPP extension - _overdrive_float(waveform,temp,last_in,last_out,output_waveform) + _overdrive_helper(waveform, temp, last_in, last_out, output_waveform) return output_waveform.clamp(min=-1, max=1).view(actual_shape) diff --git a/torchaudio/torch_overdrive.cpp b/torchaudio/torch_overdrive.cpp index 8f2eb80bc3..a23f4045fe 100644 --- a/torchaudio/torch_overdrive.cpp +++ b/torchaudio/torch_overdrive.cpp @@ -1,46 +1,51 @@ #include -// TBD - for CUDA support #include -// TBD - Compile on CUDA namespace torch { namespace audio { - template - void _overdrive_float( - at::Tensor & waveform, - at::Tensor & temp, - at::Tensor & last_in, - at::Tensor & last_out, - at::Tensor & output_waveform + template + void overdrive_cpu_kernel( + at::TensorAccessor waveform_accessor, + at::TensorAccessor temp_accessor, + at::TensorAccessor last_in_accessor, + at::TensorAccessor last_out_accessor, + at::TensorAccessor output_waveform_accessor ) { - int64_t n_frames = waveform.size(1); - int64_t n_channels = waveform.size(0); - // Create CPU accessors for fast access - // https://pytorch.org/cppdocs/notes/tensor_basics.html - auto waveform_accessor = waveform.accessor(); - auto temp_accessor = temp.accessor(); - auto last_in_accessor = last_in.accessor(); - auto last_out_accessor = last_out.accessor(); - auto output_waveform_accessor = output_waveform.accessor(); + int64_t n_frames = waveform_accessor.size(1); + int64_t n_channels = waveform_accessor.size(0); for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) { for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { + last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; output_waveform_accessor[i_channel][i_frame]= waveform_accessor[i_channel][i_frame] * 0.5 + last_out_accessor[i_channel] * 0.75; + } } - /* - for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { - last_out = temp.slice(1,i_frame,i_frame+1) - last_in + 0.995 * last_out; - last_in = temp.slice(1,i_frame,i_frame+1); - output_waveform.slice(1,i_frame,i_frame+1) = waveform.slice(1,i_frame,i_frame+1) * 0.5 + last_out * 0.75; + } + + + void _overdrive_helper_cpu( + at::Tensor & waveform, + at::Tensor & temp, + at::Tensor & last_in, + at::Tensor & last_out, + at::Tensor & output_waveform + ) { + - } - */ + AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(),"overdrive_cpu",([&]{ + overdrive_cpu_kernel( + waveform.accessor(), + temp.accessor(), + last_in.accessor(), + last_out.accessor(), + output_waveform.accessor()); + } )); } @@ -50,7 +55,6 @@ namespace audio { PYBIND11_MODULE(_torch_overdrive, m) { m.def( - "_overdrive_float", - &torch::audio::_overdrive_float, - "Executes difference equation with tensor"); + "_overdrive_helper",&torch::audio::_overdrive_helper_cpu, + "Executes loop for overdrive effect"); } From 3a29cd666ef2276c0ecca3a5370575d230e7a98e Mon Sep 17 00:00:00 2001 From: Bhargav Kathivarapu Date: Fri, 24 Apr 2020 19:26:51 +0530 Subject: [PATCH 3/3] cpp linting Signed-off-by: Bhargav Kathivarapu --- torchaudio/torch_overdrive.cpp | 86 +++++++++++++++------------------- 1 file changed, 39 insertions(+), 47 deletions(-) diff --git a/torchaudio/torch_overdrive.cpp b/torchaudio/torch_overdrive.cpp index a23f4045fe..d3388977af 100644 --- a/torchaudio/torch_overdrive.cpp +++ b/torchaudio/torch_overdrive.cpp @@ -3,58 +3,50 @@ namespace torch { namespace audio { - - template - void overdrive_cpu_kernel( - at::TensorAccessor waveform_accessor, - at::TensorAccessor temp_accessor, - at::TensorAccessor last_in_accessor, - at::TensorAccessor last_out_accessor, - at::TensorAccessor output_waveform_accessor - ) { - int64_t n_frames = waveform_accessor.size(1); - int64_t n_channels = waveform_accessor.size(0); - - for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) { - - for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { - - last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; - last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; - output_waveform_accessor[i_channel][i_frame]= waveform_accessor[i_channel][i_frame] * 0.5 + last_out_accessor[i_channel] * 0.75; - - } +template +void overdrive_cpu_kernel( + at::TensorAccessor waveform_accessor, + at::TensorAccessor temp_accessor, + at::TensorAccessor last_in_accessor, + at::TensorAccessor last_out_accessor, + at::TensorAccessor output_waveform_accessor) { + int64_t n_frames = waveform_accessor.size(1); + int64_t n_channels = waveform_accessor.size(0); + + for (int64_t i_channel = 0; i_channel < n_channels; ++i_channel) { + for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) { + last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] - + last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel]; + last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame]; + output_waveform_accessor[i_channel][i_frame] = + waveform_accessor[i_channel][i_frame] * 0.5 + + last_out_accessor[i_channel] * 0.75; } - - } - - - void _overdrive_helper_cpu( - at::Tensor & waveform, - at::Tensor & temp, - at::Tensor & last_in, - at::Tensor & last_out, - at::Tensor & output_waveform - ) { - - - AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(),"overdrive_cpu",([&]{ - overdrive_cpu_kernel( - waveform.accessor(), - temp.accessor(), - last_in.accessor(), - last_out.accessor(), - output_waveform.accessor()); - } )); - } +} -} // namespace audio -} // namespace torch +void _overdrive_helper_cpu( + at::Tensor& waveform, + at::Tensor& temp, + at::Tensor& last_in, + at::Tensor& last_out, + at::Tensor& output_waveform) { + AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] { + overdrive_cpu_kernel( + waveform.accessor(), + temp.accessor(), + last_in.accessor(), + last_out.accessor(), + output_waveform.accessor()); + })); +} +} // namespace audio +} // namespace torch PYBIND11_MODULE(_torch_overdrive, m) { m.def( - "_overdrive_helper",&torch::audio::_overdrive_helper_cpu, - "Executes loop for overdrive effect"); + "_overdrive_helper", + &torch::audio::_overdrive_helper_cpu, + "Executes helper loop for overdrive effect"); }