diff --git a/test/test_functional.py b/test/test_functional.py index 4e36b4dce2..90bc1f541f 100644 --- a/test/test_functional.py +++ b/test/test_functional.py @@ -5,6 +5,7 @@ import torch import torchaudio import torchaudio.functional as F +import torchaudio.transforms as T import pytest import unittest import common_utils @@ -31,8 +32,10 @@ class TestFunctional(unittest.TestCase): specgram = torch.tensor([1., 2., 3., 4.]) test_dirpath, test_dir = common_utils.create_temp_assets_dir() + test_filepath = os.path.join(test_dirpath, 'assets', 'steam-train-whistle-daniel_simon.mp3') + waveform_train, sr_train = torchaudio.load(test_filepath) def test_torchscript_spectrogram(self): @@ -365,8 +368,63 @@ def test_create_fb(self): self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0) self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0) - def test_pitch(self): + def test_gain(self): + waveform_gain = F.gain(self.waveform_train, 3) + self.assertTrue(waveform_gain.abs().max().item(), 1.) + + E = torchaudio.sox_effects.SoxEffectsChain() + E.set_input_file(self.test_filepath) + E.append_effect_to_chain("gain", [3]) + sox_gain_waveform = E.sox_build_flow_effects()[0] + + self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04)) + + def test_scale_to_interval(self): + scaled = 5.5 # [-5.5, 5.5] + waveform_scaled = F._scale_to_interval(self.waveform_train, scaled) + + self.assertTrue(torch.max(waveform_scaled) <= scaled) + self.assertTrue(torch.min(waveform_scaled) >= -scaled) + + def test_dither(self): + waveform_dithered = F.dither(self.waveform_train) + waveform_dithered_noiseshaped = F.dither(self.waveform_train, noise_shaping=True) + + E = torchaudio.sox_effects.SoxEffectsChain() + E.set_input_file(self.test_filepath) + E.append_effect_to_chain("dither", []) + sox_dither_waveform = E.sox_build_flow_effects()[0] + + self.assertTrue(torch.allclose(waveform_dithered, sox_dither_waveform, atol=1e-04)) + E.clear_chain() + + E.append_effect_to_chain("dither", ["-s"]) + sox_dither_waveform_ns = E.sox_build_flow_effects()[0] + + self.assertTrue(torch.allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02)) + + def test_vctk_transform_pipeline(self): + test_filepath_vctk = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav") + wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk) + + # rate + sample = T.Resample(sr_vctk, 16000, resampling_method='sinc_interpolation') + wf_vctk = sample(wf_vctk) + # dither + wf_vctk = F.dither(wf_vctk, noise_shaping=True) + E = torchaudio.sox_effects.SoxEffectsChain() + E.set_input_file(test_filepath_vctk) + E.append_effect_to_chain("gain", ["-h"]) + E.append_effect_to_chain("channels", [1]) + E.append_effect_to_chain("rate", [16000]) + E.append_effect_to_chain("gain", ["-rh"]) + E.append_effect_to_chain("dither", ["-s"]) + wf_vctk_sox = E.sox_build_flow_effects()[0] + + self.assertTrue(torch.allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03)) + + def test_pitch(self): test_dirpath, test_dir = common_utils.create_temp_assets_dir() test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav") test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav") @@ -518,6 +576,25 @@ def test_mask_along_axis_iid(self): _test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) + def test_torchscript_gain(self): + tensor = torch.rand((1, 1000)) + gainDB = 2.0 + + _test_torchscript_functional(F.gain, tensor, gainDB) + + def test_torchscript_scale_to_interval(self): + tensor = torch.rand((1, 1000)) + scaled = 3.5 + + _test_torchscript_functional(F._scale_to_interval, tensor, scaled) + + def test_torchscript_dither(self): + tensor = torch.rand((1, 1000)) + + _test_torchscript_functional(F.dither, tensor) + _test_torchscript_functional(F.dither, tensor, "RPDF") + _test_torchscript_functional(F.dither, tensor, "GPDF") + @pytest.mark.parametrize('complex_tensor', [ torch.randn(1, 2, 1025, 400, 2), diff --git a/torchaudio/datasets/vctk.py b/torchaudio/datasets/vctk.py index 813a9df62a..4e57b2b398 100644 --- a/torchaudio/datasets/vctk.py +++ b/torchaudio/datasets/vctk.py @@ -21,18 +21,16 @@ def load_vctk_item( # Read wav file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio) + waveform, sample_rate = torchaudio.load(file_audio) if downsample: - # Legacy - E = torchaudio.sox_effects.SoxEffectsChain() - E.set_input_file(file_audio) - E.append_effect_to_chain("gain", ["-h"]) - E.append_effect_to_chain("channels", [1]) - E.append_effect_to_chain("rate", [16000]) - E.append_effect_to_chain("gain", ["-rh"]) - E.append_effect_to_chain("dither", ["-s"]) - waveform, sample_rate = E.sox_build_flow_effects() - else: - waveform, sample_rate = torchaudio.load(file_audio) + # TODO Remove this parameter after deprecation + F = torchaudio.functional + T = torchaudio.transforms + # rate + sample = T.Resample(sample_rate, 16000, resampling_method='sinc_interpolation') + waveform = sample(waveform) + # dither + waveform = F.dither(waveform, noise_shaping=True) return waveform, sample_rate, utterance, speaker_id, utterance_id diff --git a/torchaudio/functional.py b/torchaudio/functional.py index 496de76357..8a5dae8698 100644 --- a/torchaudio/functional.py +++ b/torchaudio/functional.py @@ -858,6 +858,162 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): return output +def gain(waveform, gain_db=1.0): + # type: (Tensor, float) -> Tensor + r"""Apply amplification or attenuation to the whole waveform. + + Args: + waveform (torch.Tensor): Tensor of audio of dimension (channel, time). + gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`). + + Returns: + torch.Tensor: the whole waveform amplified by gain_db. + """ + if (gain_db == 0): + return waveform + + ratio = 10 ** (gain_db / 20) + + return waveform * ratio + + +def _scale_to_interval(waveform, interval_max=1.0): + # type: (Tensor, float) -> Tensor + r"""Scale the waveform to the interval [-interval_max, interval_max] across all dimensions. + + Args: + waveform (torch.Tensor): Tensor of audio of dimension (channel, time). + interval_max (float): The bounds of the interval, where the float indicates + the upper bound and the negative of the float indicates the lower + bound (Default: `1.0`). + Example: interval=1.0 -> [-1.0, 1.0] + + Returns: + torch.Tensor: the whole waveform scaled to interval. + """ + abs_max = torch.max(torch.abs(waveform)) + ratio = abs_max / interval_max + waveform /= ratio + + return waveform + + +def _add_noise_shaping(dithered_waveform, waveform): + r"""Noise shaping is calculated by error: + error[n] = dithered[n] - original[n] + noise_shaped_waveform[n] = dithered[n] + error[n-1] + """ + wf_shape = waveform.size() + waveform = waveform.reshape(-1, wf_shape[-1]) + + dithered_shape = dithered_waveform.size() + dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1]) + + error = dithered_waveform - waveform + + # add error[n-1] to dithered_waveform[n], so offset the error by 1 index + for index in range(error.size()[0]): + err = error[index] + error_offset = torch.cat((torch.zeros(1), err)) + error[index] = error_offset[:waveform.size()[1]] + + noise_shaped = dithered_waveform + error + return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:]) + + +def _apply_probability_distribution(waveform, density_function="TPDF"): + # type: (Tensor, str) -> Tensor + r"""Apply a probability distribution function on a waveform. + + Triangular probability density function (TPDF) dither noise has a + triangular distribution; values in the center of the range have a higher + probability of occurring. + + Rectangular probability density function (RPDF) dither noise has a + uniform distribution; any value in the specified range has the same + probability of occurring. + + Gaussian probability density function (GPDF) has a normal distribution. + The relationship of probabilities of results follows a bell-shaped, + or Gaussian curve, typical of dither generated by analog sources. + Args: + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + probability_density_function (string): The density function of a + continuous random variable (Default: `TPDF`) + Options: Triangular Probability Density Function - `TPDF` + Rectangular Probability Density Function - `RPDF` + Gaussian Probability Density Function - `GPDF` + Returns: + torch.Tensor: waveform dithered with TPDF + """ + shape = waveform.size() + waveform = waveform.reshape(-1, shape[-1]) + + channel_size = waveform.size()[0] - 1 + time_size = waveform.size()[-1] - 1 + + random_channel = int(torch.randint(channel_size, [1, ]).item()) if channel_size > 0 else 0 + random_time = int(torch.randint(time_size, [1, ]).item()) if time_size > 0 else 0 + + number_of_bits = 16 + up_scaling = 2 ** (number_of_bits - 1) - 2 + signal_scaled = waveform * up_scaling + down_scaling = 2 ** (number_of_bits - 1) + + signal_scaled_dis = waveform + if (density_function == "RPDF"): + RPDF = waveform[random_channel][random_time] - 0.5 + + signal_scaled_dis = signal_scaled + RPDF + elif (density_function == "GPDF"): + # TODO Replace by distribution code once + # https://github.com/pytorch/pytorch/issues/29843 is resolved + # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample() + + num_rand_variables = 6 + + gaussian = waveform[random_channel][random_time] + for ws in num_rand_variables * [time_size]: + rand_chan = int(torch.randint(channel_size, [1, ]).item()) + gaussian += waveform[rand_chan][int(torch.randint(ws, [1, ]).item())] + + signal_scaled_dis = signal_scaled + gaussian + else: + TPDF = torch.bartlett_window(time_size + 1) + TPDF = TPDF.repeat((channel_size + 1), 1) + signal_scaled_dis = signal_scaled + TPDF + + quantised_signal_scaled = torch.round(signal_scaled_dis) + quantised_signal = quantised_signal_scaled / down_scaling + return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:]) + + +def dither(waveform, density_function="TPDF", noise_shaping=False): + # type: (Tensor, str, bool) -> Tensor + r"""Dither increases the perceived dynamic range of audio stored at a + particular bit-depth by eliminating nonlinear truncation distortion + (i.e. adding minimally perceived noise to mask distortion caused by quantization). + Args: + waveform (torch.Tensor): Tensor of audio of dimension (channel, time) + density_function (string): The density function of a + continuous random variable (Default: `TPDF`) + Options: Triangular Probability Density Function - `TPDF` + Rectangular Probability Density Function - `RPDF` + Gaussian Probability Density Function - `GPDF` + noise_shaping (boolean): a filtering process that shapes the spectral + energy of quantisation error (Default: `False`) + + Returns: + torch.Tensor: waveform dithered + """ + dithered = _apply_probability_distribution(waveform, density_function=density_function) + + if noise_shaping: + return _add_noise_shaping(dithered, waveform) + else: + return dithered + + def _compute_nccf(waveform, sample_rate, frame_time, freq_low): # type: (Tensor, int, float, int) -> Tensor r"""