diff --git a/docs/source/functional.rst b/docs/source/functional.rst
index a8d57bb36c..9968bfadf7 100644
--- a/docs/source/functional.rst
+++ b/docs/source/functional.rst
@@ -62,3 +62,23 @@ Functions to perform common audio operations.
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 
 .. autofunction:: phase_vocoder
+
+:hidden:`lfilter`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: lfilter
+
+:hidden:`biquad`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: biquad
+
+:hidden:`lowpass_biquad`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: lowpass_biquad
+
+:hidden:`highpass_biquad`
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autofunction:: highpass_biquad
diff --git a/test/assets/dtmf_30s_stereo.mp3 b/test/assets/dtmf_30s_stereo.mp3
new file mode 100644
index 0000000000..6c97835ec0
Binary files /dev/null and b/test/assets/dtmf_30s_stereo.mp3 differ
diff --git a/test/assets/whitenoise.mp3 b/test/assets/whitenoise.mp3
new file mode 100644
index 0000000000..d6fe9f44b9
Binary files /dev/null and b/test/assets/whitenoise.mp3 differ
diff --git a/test/assets/whitenoise_1min.mp3 b/test/assets/whitenoise_1min.mp3
new file mode 100644
index 0000000000..07b1198820
Binary files /dev/null and b/test/assets/whitenoise_1min.mp3 differ
diff --git a/test/test_datasets_vctk.py b/test/test_datasets_vctk.py
index 4d3477c338..d650225fe0 100644
--- a/test/test_datasets_vctk.py
+++ b/test/test_datasets_vctk.py
@@ -23,10 +23,13 @@ def test_is_audio_file(self):
     def test_make_manifest(self):
         audios = vctk.make_manifest(self.test_dirpath)
         files = ['kaldi_file.wav', 'kaldi_file_8000.wav',
-                 'sinewave.wav', 'steam-train-whistle-daniel_simon.mp3']
+                 'sinewave.wav', 'steam-train-whistle-daniel_simon.mp3',
+                 'dtmf_30s_stereo.mp3', 'whitenoise_1min.mp3', 'whitenoise.mp3']
         files = [self.get_full_path(file) for file in files]
+        files.sort()
         audios.sort()
+
         self.assertEqual(files, audios,
                          msg='files %s did not match audios %s' % (files, audios))
 
     def test_read_audio_downsample_false(self):
diff --git a/test/test_functional_filtering.py b/test/test_functional_filtering.py
new file mode 100644
index 0000000000..e4ec4f143d
--- /dev/null
+++ b/test/test_functional_filtering.py
@@ -0,0 +1,148 @@
+from __future__ import absolute_import, division, print_function, unicode_literals
+import math
+import os
+import torch
+import torchaudio
+import torchaudio.functional as F
+import unittest
+import common_utils
+import time
+
+
+class TestFunctionalFiltering(unittest.TestCase):
+    test_dirpath, test_dir = common_utils.create_temp_assets_dir()
+
+    def test_lfilter_basic(self):
+        """
+        Create a basic signal, then apply a simple delay filter
+        (b = [0, 0, 0, 1], a = [1, 0, 0, 0]).
+        The output should be the same as the input, shifted by three samples.
+        """
+
+        torch.random.manual_seed(42)
+        waveform = torch.rand(2, 44100 * 10)
+        b_coeffs = torch.tensor([0, 0, 0, 1], dtype=torch.float32)
+        a_coeffs = torch.tensor([1, 0, 0, 0], dtype=torch.float32)
+        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
+
+        assert torch.allclose(
+            waveform[:, 0:-3], output_waveform[:, 3:], atol=1e-5
+        )
+
+    def test_lfilter(self):
+        """
+        Apply an IIR lowpass filter whose coefficients were designed with
+        scipy.signal filter design:
+        https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
+
+        Example
+        >>> from scipy.signal import iirdesign
+        >>> b, a = iirdesign(0.2, 0.3, 1, 60)
+        """
+
+        b_coeffs = torch.tensor(
+            [
+                0.00299893,
+                -0.0051152,
+                0.00841964,
+                -0.00747802,
+                0.00841964,
+                -0.0051152,
+                0.00299893,
+            ]
+        )
+        a_coeffs = torch.tensor(
+            [
+                1.0,
+                -4.8155751,
+                10.2217618,
+                -12.14481273,
+                8.49018171,
+                -3.3066882,
+                0.56088705,
+            ]
+        )
+
+        filepath = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
+        waveform, sample_rate = torchaudio.load(filepath, normalization=True)
+        output_waveform = F.lfilter(waveform, a_coeffs, b_coeffs)
+        assert len(output_waveform.size()) == 2
+        assert output_waveform.size(0) == waveform.size(0)
+        assert output_waveform.size(1) == waveform.size(1)
+
+    def test_lowpass(self):
+        """
+        Test the biquad lowpass filter, comparing it to the SoX implementation
+        """
+
+        CUTOFF_FREQ = 3000
+
+        noise_filepath = os.path.join(
+            self.test_dirpath, "assets", "whitenoise.mp3"
+        )
+        E = torchaudio.sox_effects.SoxEffectsChain()
+        E.set_input_file(noise_filepath)
+        E.append_effect_to_chain("lowpass", [CUTOFF_FREQ])
+        sox_output_waveform, sr = E.sox_build_flow_effects()
+
+        waveform, sample_rate = torchaudio.load(
+            noise_filepath, normalization=True
+        )
+        output_waveform = F.lowpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
+
+        assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-4)
+
+    def test_highpass(self):
+        """
+        Test the biquad highpass filter, comparing it to the SoX implementation
+        """
+
+        CUTOFF_FREQ = 2000
+
+        noise_filepath = os.path.join(
+            self.test_dirpath, "assets", "whitenoise.mp3"
+        )
+        E = torchaudio.sox_effects.SoxEffectsChain()
+        E.set_input_file(noise_filepath)
+        E.append_effect_to_chain("highpass", [CUTOFF_FREQ])
+        sox_output_waveform, sr = E.sox_build_flow_effects()
+
+        waveform, sample_rate = torchaudio.load(
+            noise_filepath, normalization=True
+        )
+        output_waveform = F.highpass_biquad(waveform, sample_rate, CUTOFF_FREQ)
+
+        # TBD - this fails at the 1e-4 level, debug why
+        assert torch.allclose(sox_output_waveform, output_waveform, atol=1e-3)
+
+    def test_perf_biquad_filtering(self):
+
+        fn_noise = os.path.join(self.test_dirpath, "assets", "whitenoise.mp3")
+
+        b0 = 0.4
+        b1 = 0.2
+        b2 = 0.9
+        a0 = 0.7
+        a1 = 0.2
+        a2 = 0.6
+
+        # SoX method
+        E = torchaudio.sox_effects.SoxEffectsChain()
+        E.set_input_file(fn_noise)
+        _timing_sox = time.time()
+        E.append_effect_to_chain("biquad", [b0, b1, b2, a0, a1, a2])
+        waveform_sox_out, sr = E.sox_build_flow_effects()
+        _timing_sox_run_time = time.time() - _timing_sox
+
+        # lfilter method; timings are collected for manual comparison only
+        _timing_lfilter_filtering = time.time()
+        waveform, sample_rate = torchaudio.load(fn_noise, normalization=True)
+        waveform_lfilter_out = F.lfilter(
+            waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
+        )
+        _timing_lfilter_run_time = time.time() - _timing_lfilter_filtering
+
+        assert torch.allclose(waveform_sox_out, waveform_lfilter_out, atol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchaudio/functional.py b/torchaudio/functional.py
index e33cc702c7..842d8495bf 100644
--- a/torchaudio/functional.py
+++ b/torchaudio/functional.py
@@ -2,40 +2,63 @@
 import math
 import torch
 
-
 __all__ = [
-    'istft',
-    'spectrogram',
-    'amplitude_to_DB',
-    'create_fb_matrix',
-    'create_dct',
-    'mu_law_encoding',
-    'mu_law_decoding',
-    'complex_norm',
-    'angle',
-    'magphase',
-    'phase_vocoder',
+    "istft",
+    "spectrogram",
+    "amplitude_to_DB",
+    "create_fb_matrix",
+    "create_dct",
+    "mu_law_encoding",
+    "mu_law_decoding",
+    "complex_norm",
+    "angle",
+    "magphase",
+    "phase_vocoder",
+    "lfilter",
+    "lowpass_biquad",
+    "highpass_biquad",
+    "biquad",
 ]
 
-
 # TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
 @torch.jit.ignore
-def _stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided):
+def _stft(
+    waveform,
+    n_fft,
+    hop_length,
+    win_length,
+    window,
+    center,
+    pad_mode,
+    normalized,
+    onesided,
+):
     # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor
-    return torch.stft(waveform, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided)
-
-
-def istft(stft_matrix,  # type: Tensor
-          n_fft,  # type: int
-          hop_length=None,  # type: Optional[int]
-          win_length=None,  # type: Optional[int]
-          window=None,  # type: Optional[Tensor]
-          center=True,  # type: bool
-          pad_mode='reflect',  # type: str
-          normalized=False,  # type: bool
-          onesided=True,  # type: bool
-          length=None  # type: Optional[int]
-          ):
+    return torch.stft(
+        waveform,
+        n_fft,
+        hop_length,
+        win_length,
+        window,
+        center,
+        pad_mode,
+        normalized,
+        onesided,
+    )
+
+
+def istft(
+    stft_matrix,  # type: Tensor
+    n_fft,  # type: int
+    hop_length=None,  # type: Optional[int]
+    win_length=None,  # type: Optional[int]
+    window=None,  # type: Optional[Tensor]
+    center=True,  # type: bool
+    pad_mode="reflect",  # type: str
+    normalized=False,  # type: bool
+    onesided=True,  # type: bool
+    length=None,  # type: Optional[int]
+):
     # type: (...) -> Tensor
     r"""Inverse short time Fourier Transform. This is expected to be the inverse of torch.stft.
     It has the same parameters (+ additional optional parameter of ``length``) and it should return the
@@ -90,7 +113,7 @@ def istft(stft_matrix,  # type: Tensor
             (channel, signal_length) or (signal_length)
     """
     stft_matrix_dim = stft_matrix.dim()
-    assert 3 <= stft_matrix_dim <= 4, ('Incorrect stft dimension: %d' % (stft_matrix_dim))
+    assert 3 <= stft_matrix_dim <= 4, "Incorrect stft dimension: %d" % (stft_matrix_dim)
 
     if stft_matrix_dim == 3:
         # add a channel dimension
@@ -99,9 +122,13 @@ def istft(stft_matrix,  # type: Tensor
     dtype = stft_matrix.dtype
     device = stft_matrix.device
     fft_size = stft_matrix.size(1)
-    assert (onesided and n_fft // 2 + 1 == fft_size) or (not onesided and n_fft == fft_size), (
-        'one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. ' +
-        'Given values were onesided: %s, n_fft: %d, fft_size: %d' % ('True' if onesided else False, n_fft, fft_size))
+    assert (onesided and n_fft // 2 + 1 == fft_size) or (
+        not onesided and n_fft == fft_size
+    ), (
+        "one_sided implies that n_fft // 2 + 1 == fft_size and not one_sided implies n_fft == fft_size. "
+        + "Given values were onesided: %s, n_fft: %d, fft_size: %d"
+        % ("True" if onesided else "False", n_fft, fft_size)
+    )
 
     # use stft defaults for Optionals
     if win_length is None:
@@ -127,8 +154,9 @@ def istft(stft_matrix,  # type: Tensor
 
     # win_length and n_fft are synonymous from here on
     stft_matrix = stft_matrix.transpose(1, 2)  # size (channel, n_frames, fft_size, 2)
-    stft_matrix = torch.irfft(stft_matrix, 1, normalized,
-                              onesided, signal_sizes=(n_fft,))  # size (channel, n_frames, n_fft)
+    stft_matrix = torch.irfft(
+        stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,)
+    )  # size (channel, n_frames, n_fft)
 
     assert stft_matrix.size(2) == n_fft
     n_frames = stft_matrix.size(1)
@@ -137,18 +165,23 @@ def istft(stft_matrix,  # type: Tensor
 
     # each column of a channel is a frame which needs to be overlap added at the right place
     ytmp = ytmp.transpose(1, 2)  # size (channel, n_fft, n_frames)
 
-    eye = torch.eye(n_fft, requires_grad=False,
-                    device=device, dtype=dtype).unsqueeze(1)  # size (n_fft, 1, n_fft)
+    eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze(
+        1
+    )  # size (n_fft, 1, n_fft)
 
     # this does overlap add where the frames of ytmp are added such that the i'th frame of
     # ytmp is added starting at i*hop_length in the output
     y = torch.nn.functional.conv_transpose1d(
-        ytmp, eye, stride=hop_length, padding=0)  # size (channel, 1, expected_signal_len)
+        ytmp, eye, stride=hop_length, padding=0
+    )  # size (channel, 1, expected_signal_len)
 
     # do the same for the window function
-    window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)  # size (1, n_fft, n_frames)
+    window_sq = (
+        window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0)
+    )  # size (1, n_fft, n_frames)
     window_envelop = torch.nn.functional.conv_transpose1d(
-        window_sq, eye, stride=hop_length, padding=0)  # size (1, 1, expected_signal_len)
+        window_sq, eye, stride=hop_length, padding=0
+    )  # size (1, 1, expected_signal_len)
 
     expected_signal_len = n_fft + hop_length * (n_frames - 1)
     assert y.size(2) == expected_signal_len
@@ -164,7 +197,9 @@ def istft(stft_matrix,  # type: Tensor
 
     # check NOLA non-zero overlap condition
    window_envelop_lowest = window_envelop.abs().min()
-    assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest))
+    assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % (
+        window_envelop_lowest
+    )
 
     y = (y / window_envelop).squeeze(1)  # size (channel, expected_signal_len)
 
@@ -174,7 +209,9 @@ def istft(stft_matrix,  # type: Tensor
 
 
 @torch.jit.script
-def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized):
+def spectrogram(
+    waveform, pad, window, n_fft, hop_length, win_length, power, normalized
+):
     # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
     r"""Create a spectrogram from a raw audio signal.
@@ -201,8 +238,9 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized):
         waveform = torch.nn.functional.pad(waveform, (pad, pad), "constant")
 
     # default values are consistent with librosa.core.spectrum._spectrogram
-    spec_f = _stft(waveform, n_fft, hop_length, win_length, window,
-                   True, 'reflect', False, True)
+    spec_f = _stft(
+        waveform, n_fft, hop_length, win_length, window, True, "reflect", False, True
+    )
 
     if normalized:
         spec_f /= window.pow(2).sum().sqrt()
@@ -234,8 +272,9 @@ def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
     x_db -= multiplier * db_multiplier
 
     if top_db is not None:
-        new_x_db_max = torch.tensor(float(x_db.max()) - top_db,
-                                    dtype=x_db.dtype, device=x_db.device)
+        new_x_db_max = torch.tensor(
+            float(x_db.max()) - top_db, dtype=x_db.dtype, device=x_db.device
+        )
         x_db = torch.max(x_db, new_x_db_max)
 
     return x_db
@@ -263,17 +302,17 @@ def create_fb_matrix(n_freqs, f_min, f_max, n_mels):
     freqs = torch.linspace(f_min, f_max, n_freqs)
     # calculate mel freq bins
     # hertz to mel(f) is 2595. * math.log10(1. + (f / 700.))
-    m_min = 0. if f_min == 0 else 2595. * math.log10(1. + (f_min / 700.))
-    m_max = 2595. * math.log10(1. + (f_max / 700.))
+    m_min = 0.0 if f_min == 0 else 2595.0 * math.log10(1.0 + (f_min / 700.0))
+    m_max = 2595.0 * math.log10(1.0 + (f_max / 700.0))
     m_pts = torch.linspace(m_min, m_max, n_mels + 2)
     # mel to hertz(mel) is 700. * (10**(mel / 2595.) - 1.)
-    f_pts = 700. * (10**(m_pts / 2595.) - 1.)
+    f_pts = 700.0 * (10 ** (m_pts / 2595.0) - 1.0)
     # calculate the difference between each mel point and each stft freq point in hertz
     f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
     slopes = f_pts.unsqueeze(0) - freqs.unsqueeze(1)  # (n_freqs, n_mels + 2)
     # create overlapping triangles
     zero = torch.zeros(1)
-    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
+    down_slopes = (-1.0 * slopes[:, :-2]) / f_diff[:-1]  # (n_freqs, n_mels)
     up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_freqs, n_mels)
     fb = torch.max(zero, torch.min(down_slopes, up_slopes))
     return fb
@@ -301,7 +340,7 @@ def create_dct(n_mfcc, n_mels, norm):
     if norm is None:
         dct *= 2.0
     else:
-        assert norm == 'ortho'
+        assert norm == "ortho"
         dct[0] *= 1.0 / math.sqrt(2.0)
         dct *= math.sqrt(2.0 / float(n_mels))
     return dct.t()
@@ -323,12 +362,11 @@ def mu_law_encoding(x, quantization_channels):
     Returns:
         torch.Tensor: Input after mu-law encoding
     """
-    mu = quantization_channels - 1.
+    mu = quantization_channels - 1.0
     if not x.is_floating_point():
         x = x.to(torch.float)
     mu = torch.tensor(mu, dtype=x.dtype)
-    x_mu = torch.sign(x) * torch.log1p(mu *
-                                       torch.abs(x)) / torch.log1p(mu)
+    x_mu = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
     x_mu = ((x_mu + 1) / 2 * mu + 0.5).to(torch.int64)
     return x_mu
@@ -349,12 +387,12 @@ def mu_law_decoding(x_mu, quantization_channels):
     Returns:
         torch.Tensor: Input after mu-law decoding
     """
-    mu = quantization_channels - 1.
+    mu = quantization_channels - 1.0
     if not x_mu.is_floating_point():
         x_mu = x_mu.to(torch.float)
     mu = torch.tensor(mu, dtype=x_mu.dtype)
-    x = ((x_mu) / mu) * 2 - 1.
-    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
+    x = ((x_mu) / mu) * 2 - 1.0
+    x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
     return x
@@ -385,7 +423,7 @@ def angle(complex_tensor):
     return torch.atan2(complex_tensor[..., 1], complex_tensor[..., 0])
 
 
-def magphase(complex_tensor, power=1.):
+def magphase(complex_tensor, power=1.0):
     r"""Separate a complex-valued spectrogram with shape `(*, 2)` into its magnitude and phase.
 
     Args:
@@ -428,13 +466,15 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
     ndim = complex_specgrams.dim()
     time_slice = [slice(None)] * (ndim - 2)
 
-    time_steps = torch.arange(0,
-                              complex_specgrams.size(-2),
-                              rate,
-                              device=complex_specgrams.device,
-                              dtype=complex_specgrams.dtype)
+    time_steps = torch.arange(
+        0,
+        complex_specgrams.size(-2),
+        rate,
+        device=complex_specgrams.device,
+        dtype=complex_specgrams.dtype,
+    )
 
-    alphas = time_steps % 1.
+    alphas = time_steps % 1.0
     phase_0 = angle(complex_specgrams[time_slice + [slice(1)]])
 
     # Time Padding
@@ -466,3 +506,149 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
     complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
 
     return complex_specgrams_stretch
+
+
+def lfilter(waveform, a_coeffs, b_coeffs):
+    # type: (Tensor, Tensor, Tensor) -> Tensor
+    r"""
+    Performs an IIR filter by evaluating the difference equation.
+
+    Args:
+        waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`.  Must be normalized to -1 to 1.
+        a_coeffs (torch.Tensor): denominator coefficients of difference equation of dimension of `(n_order + 1)`.
+            Lower-delay coefficients come first, e.g. `[a0, a1, a2, ...]`.
+            Must be same size as b_coeffs (pad with 0's as necessary).
+        b_coeffs (torch.Tensor): numerator coefficients of difference equation of dimension of `(n_order + 1)`.
+            Lower-delay coefficients come first, e.g. `[b0, b1, b2, ...]`.
+            Must be same size as a_coeffs (pad with 0's as necessary).
+
+    Returns:
+        output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`.  Output will be clipped to -1 to 1.
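+
+    Example
+        >>> # A minimal usage sketch: b = [0, 0, 0, 1] with a = [1, 0, 0, 0]
+        >>> # simply delays the signal by three samples (input is synthetic).
+        >>> waveform = torch.rand(2, 1000) * 2 - 1
+        >>> b_coeffs = torch.tensor([0.0, 0.0, 0.0, 1.0])
+        >>> a_coeffs = torch.tensor([1.0, 0.0, 0.0, 0.0])
+        >>> delayed = lfilter(waveform, a_coeffs, b_coeffs)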
+
+    """
+
+    assert waveform.dtype == torch.float32
+    assert a_coeffs.size(0) == b_coeffs.size(0)
+    assert len(waveform.size()) == 2
+
+    n_channels, n_frames = waveform.size()
+    n_order = a_coeffs.size(0)
+    assert n_order > 0
+
+    # Pad the input and create output
+    padded_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
+    padded_waveform[:, (n_order - 1):] = waveform
+    padded_output_waveform = torch.zeros(n_channels, n_frames + n_order - 1)
+
+    # Set up the coefficients matrix
+    # Flip order, repeat, and transpose
+    a_coeffs_filled = a_coeffs.flip(0).repeat(n_channels, 1).t()
+    b_coeffs_filled = b_coeffs.flip(0).repeat(n_channels, 1).t()
+
+    # Set up a few other utilities
+    a0_repeated = torch.ones(n_channels) * a_coeffs[0]
+    ones = torch.ones(n_channels, n_frames)
+
+    for i_frame in range(n_frames):
+
+        o0 = torch.zeros(n_channels)
+
+        windowed_input_signal = padded_waveform[:, i_frame:(i_frame + n_order)]
+        windowed_output_signal = padded_output_waveform[:, i_frame:(i_frame + n_order)]
+
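+        # Each channel's new sample is the dot product of its input window with
+        # the (flipped) b coefficients, minus the dot product of its output
+        # window with the (flipped) a coefficients; torch.diag(torch.mm(...))
+        # evaluates those per-channel dot products in one batched multiply.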
" + + "Given values were onesided: %s, n_fft: %d, fft_size: %d" + % ("True" if onesided else False, n_fft, fft_size) + ) # use stft defaults for Optionals if win_length is None: @@ -127,8 +154,9 @@ def istft(stft_matrix, # type: Tensor # win_length and n_fft are synonymous from here on stft_matrix = stft_matrix.transpose(1, 2) # size (channel, n_frames, fft_size, 2) - stft_matrix = torch.irfft(stft_matrix, 1, normalized, - onesided, signal_sizes=(n_fft,)) # size (channel, n_frames, n_fft) + stft_matrix = torch.irfft( + stft_matrix, 1, normalized, onesided, signal_sizes=(n_fft,) + ) # size (channel, n_frames, n_fft) assert stft_matrix.size(2) == n_fft n_frames = stft_matrix.size(1) @@ -137,18 +165,23 @@ def istft(stft_matrix, # type: Tensor # each column of a channel is a frame which needs to be overlap added at the right place ytmp = ytmp.transpose(1, 2) # size (channel, n_fft, n_frames) - eye = torch.eye(n_fft, requires_grad=False, - device=device, dtype=dtype).unsqueeze(1) # size (n_fft, 1, n_fft) + eye = torch.eye(n_fft, requires_grad=False, device=device, dtype=dtype).unsqueeze( + 1 + ) # size (n_fft, 1, n_fft) # this does overlap add where the frames of ytmp are added such that the i'th frame of # ytmp is added starting at i*hop_length in the output y = torch.nn.functional.conv_transpose1d( - ytmp, eye, stride=hop_length, padding=0) # size (channel, 1, expected_signal_len) + ytmp, eye, stride=hop_length, padding=0 + ) # size (channel, 1, expected_signal_len) # do the same for the window function - window_sq = window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) # size (1, n_fft, n_frames) + window_sq = ( + window.pow(2).view(n_fft, 1).repeat((1, n_frames)).unsqueeze(0) + ) # size (1, n_fft, n_frames) window_envelop = torch.nn.functional.conv_transpose1d( - window_sq, eye, stride=hop_length, padding=0) # size (1, 1, expected_signal_len) + window_sq, eye, stride=hop_length, padding=0 + ) # size (1, 1, expected_signal_len) expected_signal_len = n_fft + hop_length * (n_frames - 1) assert y.size(2) == expected_signal_len @@ -164,7 +197,9 @@ def istft(stft_matrix, # type: Tensor # check NOLA non-zero overlap condition window_envelop_lowest = window_envelop.abs().min() - assert window_envelop_lowest > 1e-11, ('window overlap add min: %f' % (window_envelop_lowest)) + assert window_envelop_lowest > 1e-11, "window overlap add min: %f" % ( + window_envelop_lowest + ) y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len) @@ -174,7 +209,9 @@ def istft(stft_matrix, # type: Tensor @torch.jit.script -def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, normalized): +def spectrogram( + waveform, pad, window, n_fft, hop_length, win_length, power, normalized +): # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor r"""Create a spectrogram from a raw audio signal. 
+    """
+
+    assert waveform.dtype == torch.float32
+
+    output_waveform = lfilter(
+        waveform, torch.tensor([a0, a1, a2]), torch.tensor([b0, b1, b2])
+    )
+    return output_waveform
+
+
+def _dB2Linear(x):
+    return math.exp(x * math.log(10) / 20.0)
+
+
+def highpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
+    # type: (Tensor, int, float, Optional[float]) -> Tensor
+    r"""Designs biquad highpass filter and performs filtering.  Similar to SoX implementation.
+
+    Args:
+        waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
+        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+        cutoff_freq (float): filter cutoff frequency
+        Q (float): https://en.wikipedia.org/wiki/Q_factor
+
+    Returns:
+        output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
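+
+    Example
+        >>> # A minimal sketch: attenuate content below 2000 Hz (synthetic input).
+        >>> sample_rate = 44100
+        >>> waveform = torch.rand(1, sample_rate) * 2 - 1
+        >>> filtered = highpass_biquad(waveform, sample_rate, 2000)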
+    """
+
+    GAIN = 1  # TBD - add as a parameter
+    w0 = 2 * math.pi * cutoff_freq / sample_rate
+    A = math.exp(GAIN / 40.0 * math.log(10))
+    alpha = math.sin(w0) / 2 / Q
+    mult = _dB2Linear(max(GAIN, 0))
+
+    b0 = (1 + math.cos(w0)) / 2
+    b1 = -1 - math.cos(w0)
+    b2 = b0
+    a0 = 1 + alpha
+    a1 = -2 * math.cos(w0)
+    a2 = 1 - alpha
+    return biquad(waveform, b0, b1, b2, a0, a1, a2)
+
+
+def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
+    # type: (Tensor, int, float, Optional[float]) -> Tensor
+    r"""Designs biquad lowpass filter and performs filtering.  Similar to SoX implementation.
+
+    Args:
+        waveform (torch.Tensor): audio waveform of dimension of `(n_channel, n_frames)`
+        sample_rate (int): sampling rate of the waveform, e.g. 44100 (Hz)
+        cutoff_freq (float): filter cutoff frequency
+        Q (float): https://en.wikipedia.org/wiki/Q_factor
+
+    Returns:
+        output_waveform (torch.Tensor): Dimension of `(n_channel, n_frames)`
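+
+    Example
+        >>> # A minimal sketch: keep content below 3000 Hz (synthetic input).
+        >>> sample_rate = 44100
+        >>> waveform = torch.rand(1, sample_rate) * 2 - 1
+        >>> filtered = lowpass_biquad(waveform, sample_rate, 3000)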
+    """
+
+    GAIN = 1
+    w0 = 2 * math.pi * cutoff_freq / sample_rate
+    A = math.exp(GAIN / 40.0 * math.log(10))
+    alpha = math.sin(w0) / 2 / Q
+    mult = _dB2Linear(max(GAIN, 0))
+
+    b0 = (1 - math.cos(w0)) / 2
+    b1 = 1 - math.cos(w0)
+    b2 = b0
+    a0 = 1 + alpha
+    a1 = -2 * math.cos(w0)
+    a2 = 1 - alpha
+    return biquad(waveform, b0, b1, b2, a0, a1, a2)