diff --git a/torchaudio/csrc/CMakeLists.txt b/torchaudio/csrc/CMakeLists.txt
index 1bab67be5a..9d21dc175e 100644
--- a/torchaudio/csrc/CMakeLists.txt
+++ b/torchaudio/csrc/CMakeLists.txt
@@ -11,6 +11,7 @@ set(
   sox/effects_chain.cpp
   sox/types.cpp
   lfilter.cpp
+  overdrive.cpp
   )
 
 if(BUILD_TRANSDUCER)
diff --git a/torchaudio/csrc/overdrive.cpp b/torchaudio/csrc/overdrive.cpp
new file mode 100644
index 0000000000..4954271e41
--- /dev/null
+++ b/torchaudio/csrc/overdrive.cpp
@@ -0,0 +1,60 @@
+#include <torch/script.h>
+#include <torch/torch.h>
+
+namespace {
+
+// Applies the overdrive recurrence to each channel, one frame at a time.
+// Channel ranges are handed out by at::parallel_for; the frames of a channel
+// are visited in order because every step reads the state that the previous
+// step stored. last_in/last_out keep one value per channel and are
+// overwritten in place; results go to output_waveform.
+template <typename scalar_t>
+void overdrive_cpu_kernel(
+    at::TensorAccessor<scalar_t, 2> waveform_accessor,
+    at::TensorAccessor<scalar_t, 2> temp_accessor,
+    at::TensorAccessor<scalar_t, 1> last_in_accessor,
+    at::TensorAccessor<scalar_t, 1> last_out_accessor,
+    at::TensorAccessor<scalar_t, 2> output_waveform_accessor) {
+  int64_t n_frames = waveform_accessor.size(1);
+  int64_t n_channels = waveform_accessor.size(0);
+
+  at::parallel_for(0, n_channels, 1, [&](int64_t begin, int64_t end) {
+    for (int64_t i_channel = begin; i_channel < end; ++i_channel) {
+      for (int64_t i_frame = 0; i_frame < n_frames; ++i_frame) {
+        last_out_accessor[i_channel] = temp_accessor[i_channel][i_frame] -
+            last_in_accessor[i_channel] + 0.995 * last_out_accessor[i_channel];
+        last_in_accessor[i_channel] = temp_accessor[i_channel][i_frame];
+        output_waveform_accessor[i_channel][i_frame] =
+            waveform_accessor[i_channel][i_frame] * 0.5 +
+            last_out_accessor[i_channel] * 0.75;
+      }
+    }
+  });
+}
+
+// Selects the kernel instantiation matching waveform's floating-point dtype
+// and runs it on accessors of the five tensors. last_in, last_out and
+// output_waveform are written in place.
+void overdrive_core_loop_cpu(
+    at::Tensor& waveform,
+    at::Tensor& temp,
+    at::Tensor& last_in,
+    at::Tensor& last_out,
+    at::Tensor& output_waveform) {
+  AT_DISPATCH_FLOATING_TYPES(waveform.scalar_type(), "overdrive_cpu", ([&] {
+    overdrive_cpu_kernel<scalar_t>(
+        waveform.accessor<scalar_t, 2>(),
+        temp.accessor<scalar_t, 2>(),
+        last_in.accessor<scalar_t, 1>(),
+        last_out.accessor<scalar_t, 1>(),
+        output_waveform.accessor<scalar_t, 2>());
+  }));
+}
+
+} // namespace
+
+// Note: We want to avoid using "catch-all" kernel.
+// The following registration should be replaced with CPU specific registration.
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_overdrive_core_loop", &overdrive_core_loop_cpu);
+}
diff --git a/torchaudio/functional/filtering.py b/torchaudio/functional/filtering.py
index 3c27a0d617..94b98ac7fc 100644
--- a/torchaudio/functional/filtering.py
+++ b/torchaudio/functional/filtering.py
@@ -939,6 +939,35 @@ def lowpass_biquad(
     return biquad(waveform, b0, b1, b2, a0, a1, a2)
 
 
+def _overdrive_core_loop_generic(
+    waveform: Tensor,
+    temp: Tensor,
+    last_in: Tensor,
+    last_out: Tensor,
+    output_waveform: Tensor
+):
+    """Pure-PyTorch overdrive recurrence; fills ``output_waveform`` in place.
+
+    Walks ``waveform.shape[-1]`` frames, indexing each one as ``[:, i]``.
+    ``last_in`` and ``last_out`` are only rebound locally, so the tensors
+    passed in for them are not modified.
+    """
+    for i in range(waveform.shape[-1]):
+        last_out = temp[:, i] - last_in + 0.995 * last_out
+        last_in = temp[:, i]
+        output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75
+
+
+# Use the compiled op when it is registered; otherwise fall back to the
+# Python loop. Any other lookup failure is re-raised unchanged.
+try:
+    _overdrive_core_loop_cpu = torch.ops.torchaudio._overdrive_core_loop
+except RuntimeError as err:
+    if str(err) != 'No such operator torchaudio::_overdrive_core_loop':
+        raise
+    _overdrive_core_loop_cpu = _overdrive_core_loop_generic
+
+
 def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
     r"""Apply a overdrive effect to the audio. Similar to SoX implementation.
     This effect applies a non linear distortion to the audio signal.
@@ -981,11 +1010,11 @@ def overdrive(waveform: Tensor, gain: float = 20, colour: float = 20) -> Tensor:
 
     output_waveform = torch.zeros_like(waveform, dtype=dtype, device=device)
 
-    # TODO: Implement a torch CPP extension
-    for i in range(waveform.shape[-1]):
-        last_out = temp[:, i] - last_in + 0.995 * last_out
-        last_in = temp[:, i]
-        output_waveform[:, i] = waveform[:, i] * 0.5 + last_out * 0.75
+    # Uses CPU optimized loop function if available for CPU device
+    if device == torch.device('cpu'):
+        _overdrive_core_loop_cpu(waveform, temp, last_in, last_out, output_waveform)
+    else:
+        _overdrive_core_loop_generic(waveform, temp, last_in, last_out, output_waveform)
 
     return output_waveform.clamp(min=-1, max=1).view(actual_shape)
 