diff --git a/docs/source/index.rst b/docs/source/index.rst index c6d0efde69..cee5075d92 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -13,6 +13,7 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio kaldi_io transforms functional + utils .. automodule:: torchaudio :members: diff --git a/docs/source/sox_effects.rst b/docs/source/sox_effects.rst index 56cd985d0a..46c0a74552 100644 --- a/docs/source/sox_effects.rst +++ b/docs/source/sox_effects.rst @@ -4,10 +4,16 @@ torchaudio.sox_effects ====================== -Create SoX effects chain for preprocessing audio. - .. currentmodule:: torchaudio.sox_effects +Apply SoX effects chain on torch.Tensor or on file and load as torch.Tensor. + +.. autofunction:: apply_effects_tensor + +.. autofunction:: apply_effects_file + +Create SoX effects chain for preprocessing audio. + :hidden:`SoxEffect` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ diff --git a/docs/source/utils.rst b/docs/source/utils.rst new file mode 100644 index 0000000000..b56aabb7bb --- /dev/null +++ b/docs/source/utils.rst @@ -0,0 +1,21 @@ +.. role:: hidden + :class: hidden-section + +torchaudio.utils.sox_utils +========================== + +Utility module to configure libsox. This affects functionalities in ``sox_io`` backend and ``torchaudio.sox_effects``. + +.. currentmodule:: torchaudio.utils.sox_utils + +.. autofunction:: set_seed + +.. autofunction:: set_verbosity + +.. autofunction:: set_buffer_size + +.. autofunction:: set_use_threads + +.. autofunction:: list_effects + +.. autofunction:: list_formats diff --git a/test/assets/sox_effect_test_args.json b/test/assets/sox_effect_test_args.json new file mode 100644 index 0000000000..185e695583 --- /dev/null +++ b/test/assets/sox_effect_test_args.json @@ -0,0 +1,88 @@ +{"effects": [["allpass", "300", "10"]]} +{"effects": [["band", "300", "10"]]} +{"effects": [["bandpass", "300", "10"]]} +{"effects": [["bandreject", "300", "10"]]} +{"effects": [["bass", "-10"]]} +{"effects": [["bend", ".35,180,.25", ".15,740,.53", "0,-520,.3"]]} +{"effects": [["biquad", "0.4", "0.2", "0.9", "0.7", "0.2", "0.6"]]} +{"effects": [["chorus", "0.7", "0.9", "55", "0.4", "0.25", "2", "-t"]]} +{"effects": [["chorus", "0.6", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "1.3", "-s"]]} +{"effects": [["chorus", "0.5", "0.9", "50", "0.4", "0.25", "2", "-t", "60", "0.32", "0.4", "2.3", "-t", "40", "0.3", "0.3", "1.3", "-s"]]} +{"effects": [["channels", "1"]]} +{"effects": [["channels", "2"]]} +{"effects": [["channels", "3"]]} +{"effects": [["compand", "0.3,1", "6:-70,-60,-20", "-5", "-90", "0.2"]]} +{"effects": [["compand", ".1,.2", "-inf,-50.1,-inf,-50,-50", "0", "-90", ".1"]]} +{"effects": [["compand", ".1,.1", "-45.1,-45,-inf,0,-inf", "45", "-90", ".1"]]} +{"effects": [["contrast", "0"]]} +{"effects": [["contrast", "25"]]} +{"effects": [["contrast", "50"]]} +{"effects": [["contrast", "75"]]} +{"effects": [["contrast", "100"]]} +{"effects": [["dcshift", "1.0"]]} +{"effects": [["dcshift", "-1.0"]]} +{"effects": [["deemph"]], "input_sample_rate": 44100} +{"effects": [["delay", "1.5", "+1"]]} +{"effects": [["dither", "-s"]]} +{"effects": [["dither", "-S"]]} +{"effects": [["divide"]]} +{"effects": [["downsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 4000} +{"effects": [["earwax"]], "input_sample_rate": 44100} +{"effects": [["echo", "0.8", "0.88", "60", "0.4"]]} +{"effects": [["echo", "0.8", "0.88", "6", "0.4"]]} +{"effects": [["echo", "0.8", "0.9", "1000", "0.3"]]} +{"effects": [["echo", "0.8", "0.9", "1000", "0.3", "1800", "0.25"]]} +{"effects": [["echos", "0.8", "0.7", "700", "0.25", "700", "0.3"]]} +{"effects": [["echos", "0.8", "0.7", "700", "0.25", "900", "0.3"]]} +{"effects": [["echos", "0.8", "0.7", "40", "0.25", "63", "0.3"]]} +{"effects": [["equalizer", "300", "10", "5"]]} +{"effects": [["fade", "q", "3"]]} +{"effects": [["fade", "h", "3"]]} +{"effects": [["fade", "t", "3"]]} +{"effects": [["fade", "l", "3"]]} +{"effects": [["fade", "p", "3"]]} +{"effects": [["fir", "0.0195", "-0.082", "0.234", "0.891", "-0.145", "0.043"]]} +{"effects": [["fir", "test/assets/sox_effect_test_fir_coeffs.txt"]]} +{"effects": [["flanger"]]} +{"effects": [["gain", "-n"]]} +{"effects": [["gain", "-n", "-3"]]} +{"effects": [["gain", "-l", "-6"]]} +{"effects": [["highpass", "-1", "300"]]} +{"effects": [["highpass", "-2", "300"]]} +{"effects": [["hilbert"]]} +{"effects": [["loudness"]]} +{"effects": [["lowpass", "-1", "300"]]} +{"effects": [["lowpass", "-2", "300"]]} +{"effects": [["mcompand", "0.005,0.1 -47,-40,-34,-34,-17,-33", "100", "0.003,0.05 -47,-40,-34,-34,-17,-33", "400", "0.000625,0.0125 -47,-40,-34,-34,-15,-33", "1600", "0.0001,0.025 -47,-40,-34,-34,-31,-31,-0,-30", "6400", "0,0.025 -38,-31,-28,-28,-0,-25"]], "input_sample_rate": 44100} +{"effects": [["norm"]]} +{"effects": [["oops"]]} +{"effects": [["overdrive"]]} +{"effects": [["pad"]]} +{"effects": [["phaser"]]} +{"effects": [["pitch", "6.48"], ["rate", "8030"]], "output_sample_rate": 8030} +{"effects": [["pitch", "-6.50"], ["rate", "7970"]], "output_sample_rate": 7970} +{"effects": [["rate", "4567"]], "output_sample_rate": 4567} +{"effects": [["remix", "6", "7", "8", "0"]], "num_channels": 8} +{"effects": [["remix", "1-3,7", "3"]], "num_channels": 8} +{"effects": [["repeat"]]} +{"effects": [["reverb"]]} +{"effects": [["reverse"]]} +{"effects": [["riaa"]], "input_sample_rate": 44100} +{"effects": [["silence", "0"]]} +{"effects": [["sinc", "3k"]]} +{"effects": [["speed", "1.3"]], "input_sample_rate": 4000, "output_sample_rate": 5200} +{"effects": [["speed", "0.7"]], "input_sample_rate": 4000, "output_sample_rate": 2800} +{"effects": [["stat"]]} +{"effects": [["stats"]]} +{"effects": [["stretch"]]} +{"effects": [["swap"]]} +{"effects": [["synth"]]} +{"effects": [["tempo", "0.9"]]} +{"effects": [["tempo", "1.1"]]} +{"effects": [["treble", "3"]]} +{"effects": [["tremolo", "300", "40"]]} +{"effects": [["tremolo", "300", "50"]]} +{"effects": [["trim", "0", "0.1"]]} +{"effects": [["upsample", "2"]], "input_sample_rate": 8000, "output_sample_rate": 16000} +{"effects": [["vad"]]} +{"effects": [["vol", "3"]]} diff --git a/test/assets/sox_effect_test_fir_coeffs.txt b/test/assets/sox_effect_test_fir_coeffs.txt new file mode 100644 index 0000000000..903a607d3b --- /dev/null +++ b/test/assets/sox_effect_test_fir_coeffs.txt @@ -0,0 +1 @@ +0.0195 -0.082 0.234 0.891 -0.145 0.043 diff --git a/test/common_utils/data_utils.py b/test/common_utils/data_utils.py index b948ce334a..321e24ce0a 100644 --- a/test/common_utils/data_utils.py +++ b/test/common_utils/data_utils.py @@ -72,6 +72,7 @@ def get_sinusoid( n_channels: int = 1, dtype: Union[str, torch.dtype] = "float32", device: Union[str, torch.device] = "cpu", + channels_first: bool = True, ): """Generate pseudo audio data with sine wave. @@ -91,4 +92,7 @@ def get_sinusoid( pie2 = 2 * 3.141592653589793 end = pie2 * frequency * duration theta = torch.linspace(0, end, sample_rate * duration, dtype=dtype, device=device) - return torch.sin(theta, out=None).repeat([n_channels, 1]) + sin = torch.sin(theta, out=None).repeat([n_channels, 1]) + if not channels_first: + sin = sin.t() + return sin diff --git a/test/common_utils/sox_utils.py b/test/common_utils/sox_utils.py index cd1c247b72..db131cdec5 100644 --- a/test/common_utils/sox_utils.py +++ b/test/common_utils/sox_utils.py @@ -77,3 +77,24 @@ def convert_audio_file( command += [dst_path] print(' '.join(command)) subprocess.run(command, check=True) + + +def _flattern(effects): + if not effects: + return effects + if isinstance(effects[0], str): + return effects + return [item for sublist in effects for item in sublist] + + +def run_sox_effect(input_file, output_file, effect, *, output_sample_rate=None, output_bitdepth=None): + """Run sox effects""" + effect = _flattern(effect) + command = ['sox', '-V', '--no-dither', input_file] + if output_bitdepth: + command += ['--bits', str(output_bitdepth)] + command += [output_file] + effect + if output_sample_rate: + command += ['rate', str(output_sample_rate)] + print(' '.join(command)) + subprocess.run(command, check=True) diff --git a/test/common_utils/test_case_utils.py b/test/common_utils/test_case_utils.py index f3b0c343a6..253e2166fb 100644 --- a/test/common_utils/test_case_utils.py +++ b/test/common_utils/test_case_utils.py @@ -14,33 +14,28 @@ class TempDirMixin: """Mixin to provide easy access to temp dir""" temp_dir_ = None - base_temp_dir = None - temp_dir = None - @classmethod - def setUpClass(cls): - super().setUpClass() + @property + def base_temp_dir(self): # If TORCHAUDIO_TEST_TEMP_DIR is set, use it instead of temporary directory. # this is handy for debugging. key = 'TORCHAUDIO_TEST_TEMP_DIR' if key in os.environ: - cls.base_temp_dir = os.environ[key] - else: - cls.temp_dir_ = tempfile.TemporaryDirectory() - cls.base_temp_dir = cls.temp_dir_.name + return os.environ[key] + if self.__class__.temp_dir_ is None: + self.__class__.temp_dir_ = tempfile.TemporaryDirectory() + return self.__class__.temp_dir_.name @classmethod def tearDownClass(cls): super().tearDownClass() - if isinstance(cls.temp_dir_, tempfile.TemporaryDirectory): + if cls.temp_dir_ is not None: cls.temp_dir_.cleanup() - - def setUp(self): - super().setUp() - self.temp_dir = os.path.join(self.base_temp_dir, self.id()) + cls.temp_dir_ = None def get_temp_path(self, *paths): - path = os.path.join(self.temp_dir, *paths) + temp_dir = os.path.join(self.base_temp_dir, self.id()) + path = os.path.join(temp_dir, *paths) os.makedirs(os.path.dirname(path), exist_ok=True) return path diff --git a/test/sox_effect/__init__.py b/test/sox_effect/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/sox_effect/common.py b/test/sox_effect/common.py new file mode 100644 index 0000000000..2eafa1e992 --- /dev/null +++ b/test/sox_effect/common.py @@ -0,0 +1,6 @@ +def name_func(func, _, params): + if isinstance(params.args[0], str): + args = "_".join([str(arg) for arg in params.args]) + else: + args = "_".join([str(arg) for arg in params.args[0]]) + return f'{func.__name__}_{args}' diff --git a/test/sox_effect/test_dataset.py b/test/sox_effect/test_dataset.py new file mode 100644 index 0000000000..51aeb975aa --- /dev/null +++ b/test/sox_effect/test_dataset.py @@ -0,0 +1,115 @@ +from typing import List, Tuple + +import numpy as np +import torch +import torchaudio + +from ..common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoExtension, + get_whitenoise, + load_wav, + save_wav, +) + + +class RandomPerturbationFile(torch.utils.data.Dataset): + """Given flist, apply random speed perturbation""" + def __init__(self, flist: List[str], sample_rate: int): + super().__init__() + self.flist = flist + self.sample_rate = sample_rate + self.rng = None + + def __getitem__(self, index): + speed = self.rng.uniform(0.5, 2.0) + effects = [ + ['gain', '-n', '-10'], + ['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds. + ['rate', f'{self.sample_rate}'], + ['pad', '0', '1.5'], # add 1.5 seconds silence at the end + ['trim', '0', '2'], # get the first 2 seconds + ] + data, _ = torchaudio.sox_effects.apply_effects_file(self.flist[index], effects) + return data + + def __len__(self): + return len(self.flist) + + +class RandomPerturbationTensor(torch.utils.data.Dataset): + """Apply speed purturbation to (synthetic) Tensor data""" + def __init__(self, signals: List[Tuple[torch.Tensor, int]], sample_rate: int): + super().__init__() + self.signals = signals + self.sample_rate = sample_rate + self.rng = None + + def __getitem__(self, index): + speed = self.rng.uniform(0.5, 2.0) + effects = [ + ['gain', '-n', '-10'], + ['speed', f'{speed:.5f}'], # duration of data is 0.5 ~ 2.0 seconds. + ['rate', f'{self.sample_rate}'], + ['pad', '0', '1.5'], # add 1.5 seconds silence at the end + ['trim', '0', '2'], # get the first 2 seconds + ] + tensor, sample_rate = self.signals[index] + data, _ = torchaudio.sox_effects.apply_effects_tensor(tensor, sample_rate, effects) + return data + + def __len__(self): + return len(self.signals) + + +def init_random_seed(worker_id): + dataset = torch.utils.data.get_worker_info().dataset + dataset.rng = np.random.RandomState(worker_id) + + +@skipIfNoExtension +class TestSoxEffectsDataset(TempDirMixin, PytorchTestCase): + """Test `apply_effects_file` in multi-process dataloader setting""" + + def _generate_dataset(self, num_samples=128): + flist = [] + for i in range(num_samples): + sample_rate = np.random.choice([8000, 16000, 44100]) + dtype = np.random.choice(['float32', 'int32', 'int16', 'uint8']) + data = get_whitenoise(n_channels=2, sample_rate=sample_rate, duration=1, dtype=dtype) + path = self.get_temp_path(f'{i:03d}_{dtype}_{sample_rate}.wav') + save_wav(path, data, sample_rate) + flist.append(path) + return flist + + def test_apply_effects_file(self): + sample_rate = 12000 + flist = self._generate_dataset() + dataset = RandomPerturbationFile(flist, sample_rate) + loader = torch.utils.data.DataLoader( + dataset, batch_size=32, num_workers=16, + worker_init_fn=init_random_seed, + ) + for batch in loader: + assert batch.shape == (32, 2, 2 * sample_rate) + + def _generate_signals(self, num_samples=128): + signals = [] + for _ in range(num_samples): + sample_rate = np.random.choice([8000, 16000, 44100]) + data = get_whitenoise( + n_channels=2, sample_rate=sample_rate, duration=1, dtype='float32') + signals.append((data, sample_rate)) + return signals + + def test_apply_effects_tensor(self): + sample_rate = 12000 + signals = self._generate_signals() + dataset = RandomPerturbationTensor(signals, sample_rate) + loader = torch.utils.data.DataLoader( + dataset, batch_size=32, num_workers=16, + worker_init_fn=init_random_seed, + ) + for batch in loader: + assert batch.shape == (32, 2, 2 * sample_rate) diff --git a/test/sox_effect/test_sox_effect.py b/test/sox_effect/test_sox_effect.py new file mode 100644 index 0000000000..8a619d52b9 --- /dev/null +++ b/test/sox_effect/test_sox_effect.py @@ -0,0 +1,221 @@ +import itertools + +from torchaudio import sox_effects +from parameterized import parameterized + +from ..common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoExtension, + get_sinusoid, + get_wav_data, + save_wav, + load_wav, + load_params, + sox_utils, +) +from .common import ( + name_func, +) + + +@skipIfNoExtension +class TestSoxEffects(PytorchTestCase): + def test_init(self): + """Calling init_sox_effects multiple times does not crush""" + for _ in range(3): + sox_effects.init_sox_effects() + + +@skipIfNoExtension +class TestSoxEffectsTensor(TempDirMixin, PytorchTestCase): + """Test suite for `apply_effects_tensor` function""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2, 4, 8], + [True, False] + )), name_func=name_func) + def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first): + """`apply_effects_tensor` without effects should return identical data as input""" + original = get_wav_data(dtype, num_channels, channels_first=channels_first) + expected = original.clone() + found, output_sample_rate = sox_effects.apply_effects_tensor( + expected, sample_rate, [], channels_first) + + assert output_sample_rate == sample_rate + # SoxEffect should not alter the input Tensor object + self.assertEqual(original, expected) + # SoxEffect should not return the same Tensor object + assert expected is not found + # Returned Tensor should equal to the input Tensor + self.assertEqual(expected, found) + + @parameterized.expand( + load_params("sox_effect_test_args.json"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects(self, args): + """`apply_effects_tensor` should return identical data as sox command""" + effects = args['effects'] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + output_sr = args.get("output_sample_rate") + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + + original = get_sinusoid( + frequency=800, sample_rate=input_sr, + n_channels=num_channels, dtype='float32') + save_wav(input_path, original, input_sr) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_tensor(original, input_sr, effects) + + assert sr == expected_sr + self.assertEqual(expected, found) + + +@skipIfNoExtension +class TestSoxEffectsFile(TempDirMixin, PytorchTestCase): + """Test suite for `apply_effects_file` function""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2, 4, 8], + [False, True], + )), name_func=name_func) + def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first): + """`apply_effects_file` without effects should return identical data as input""" + path = self.get_temp_path('input.wav') + expected = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(path, expected, sample_rate, channels_first=channels_first) + + found, output_sample_rate = sox_effects.apply_effects_file( + path, [], normalize=False, channels_first=channels_first) + + assert output_sample_rate == sample_rate + self.assertEqual(expected, found) + + @parameterized.expand( + load_params("sox_effect_test_args.json"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects(self, args): + """`apply_effects_file` should return identical data as sox command""" + dtype = 'int32' + channels_first = True + effects = args['effects'] + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + output_sr = args.get("output_sample_rate") + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + +@skipIfNoExtension +class TestFileFormats(TempDirMixin, PytorchTestCase): + """`apply_effects_file` gives the same result as sox on various file formats""" + @parameterized.expand(list(itertools.product( + ['float32', 'int32', 'int16', 'uint8'], + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_wav(self, dtype, sample_rate, num_channels): + """`apply_effects_file` works on various wav format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, sample_rate, channels_first=channels_first) + sox_utils.run_sox_effect(input_path, reference_path, effects) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_mp3(self, sample_rate, num_channels): + """`apply_effects_file` works on various mp3 format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.mp3') + reference_path = self.get_temp_path('reference.wav') + sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + sox_utils.run_sox_effect(input_path, reference_path, effects) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, channels_first=channels_first) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected, atol=1e-4, rtol=1e-8) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_flac(self, sample_rate, num_channels): + """`apply_effects_file` works on various flac format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.flac') + reference_path = self.get_temp_path('reference.wav') + sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, channels_first=channels_first) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + + @parameterized.expand(list(itertools.product( + [8000, 16000], + [1, 2], + )), name_func=lambda f, _, p: f'{f.__name__}_{"_".join(str(arg) for arg in p.args)}') + def test_vorbis(self, sample_rate, num_channels): + """`apply_effects_file` works on various vorbis format""" + channels_first = True + effects = [['band', '300', '10']] + + input_path = self.get_temp_path('input.vorbis') + reference_path = self.get_temp_path('reference.wav') + sox_utils.gen_audio_file(input_path, sample_rate, num_channels) + sox_utils.run_sox_effect(input_path, reference_path, effects, output_bitdepth=32) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + input_path, effects, channels_first=channels_first) + save_wav(self.get_temp_path('result.wav'), found, sr, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) diff --git a/test/sox_effect/test_torchscript.py b/test/sox_effect/test_torchscript.py new file mode 100644 index 0000000000..8202a2ebb0 --- /dev/null +++ b/test/sox_effect/test_torchscript.py @@ -0,0 +1,98 @@ +from typing import List + +import torch +from torchaudio import sox_effects +from parameterized import parameterized + +from ..common_utils import ( + TempDirMixin, + PytorchTestCase, + skipIfNoExtension, + get_sinusoid, + load_params, + save_wav, +) + + +class SoxEffectTensorTransform(torch.nn.Module): + effects: List[List[str]] + + def __init__(self, effects: List[List[str]], sample_rate: int, channels_first: bool): + super().__init__() + self.effects = effects + self.sample_rate = sample_rate + self.channels_first = channels_first + + def forward(self, tensor: torch.Tensor): + return sox_effects.apply_effects_tensor( + tensor, self.sample_rate, self.effects, self.channels_first) + + +class SoxEffectFileTransform(torch.nn.Module): + effects: List[List[str]] + channels_first: bool + + def __init__(self, effects: List[List[str]], channels_first: bool): + super().__init__() + self.effects = effects + self.channels_first = channels_first + + def forward(self, path: str): + return sox_effects.apply_effects_file(path, self.effects, self.channels_first) + + +@skipIfNoExtension +class TestTorchScript(TempDirMixin, PytorchTestCase): + @parameterized.expand( + load_params("sox_effect_test_args.json"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_tensor(self, args): + effects = args['effects'] + channels_first = True + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + + trans = SoxEffectTensorTransform(effects, input_sr, channels_first) + + path = self.get_temp_path('sox_effect.zip') + torch.jit.script(trans).save(path) + trans = torch.jit.load(path) + + wav = get_sinusoid( + frequency=800, sample_rate=input_sr, + n_channels=num_channels, dtype='float32', channels_first=channels_first) + found, sr_found = trans(wav) + expected, sr_expected = sox_effects.apply_effects_tensor( + wav, input_sr, effects, channels_first) + + assert sr_found == sr_expected + self.assertEqual(expected, found) + + @parameterized.expand( + load_params("sox_effect_test_args.json"), + name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', + ) + def test_apply_effects_file(self, args): + effects = args['effects'] + channels_first = True + num_channels = args.get("num_channels", 2) + input_sr = args.get("input_sample_rate", 8000) + + trans = SoxEffectFileTransform(effects, channels_first) + + path = self.get_temp_path('sox_effect.zip') + torch.jit.script(trans).save(path) + trans = torch.jit.load(path) + + path = self.get_temp_path('input.wav') + wav = get_sinusoid( + frequency=800, sample_rate=input_sr, + n_channels=num_channels, dtype='float32', channels_first=channels_first) + save_wav(path, wav, sample_rate=input_sr, channels_first=channels_first) + + found, sr_found = trans(path) + expected, sr_expected = sox_effects.apply_effects_file(path, effects, channels_first) + + assert sr_found == sr_expected + self.assertEqual(expected, found) diff --git a/test/utils/__init__.py b/test/utils/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/test/utils/test_sox_utils.py b/test/utils/test_sox_utils.py new file mode 100644 index 0000000000..c87bf0b92e --- /dev/null +++ b/test/utils/test_sox_utils.py @@ -0,0 +1,44 @@ +from torchaudio.utils import sox_utils + +from ..common_utils import ( + PytorchTestCase, + skipIfNoExtension, +) + + +@skipIfNoExtension +class TestSoxUtils(PytorchTestCase): + """Smoke tests for sox_util module""" + def test_set_seed(self): + """`set_seed` does not crush""" + sox_utils.set_seed(0) + + def test_set_verbosity(self): + """`set_verbosity` does not crush""" + for val in range(6, 0, -1): + sox_utils.set_verbosity(val) + + def test_set_buffer_size(self): + """`set_buffer_size` does not crush""" + sox_utils.set_buffer_size(131072) + # back to default + sox_utils.set_buffer_size(8192) + + def test_set_use_threads(self): + """`set_use_threads` does not crush""" + sox_utils.set_use_threads(True) + # back to default + sox_utils.set_use_threads(False) + + def test_list_effects(self): + """`list_effects` returns the list of available effects""" + effects = sox_utils.list_effects() + # We cannot infer what effects are available, so only check some of them. + assert 'highpass' in effects + assert 'phaser' in effects + assert 'gain' in effects + + def test_list_formats(self): + """`list_formats` returns the list of supported formats""" + formats = sox_utils.list_formats() + assert 'wav' in formats diff --git a/torchaudio/__init__.py b/torchaudio/__init__.py index 16419ff3ff..f748d861fd 100644 --- a/torchaudio/__init__.py +++ b/torchaudio/__init__.py @@ -4,6 +4,7 @@ compliance, datasets, kaldi_io, + utils, sox_effects, transforms ) diff --git a/torchaudio/csrc/register.cpp b/torchaudio/csrc/register.cpp index 3c03232941..9b377a3676 100644 --- a/torchaudio/csrc/register.cpp +++ b/torchaudio/csrc/register.cpp @@ -18,6 +18,17 @@ static auto registerTensorSignal = .def("get_sample_rate", &sox_utils::TensorSignal::getSampleRate) .def("get_channels_first", &sox_utils::TensorSignal::getChannelsFirst); +static auto registerSetSoxOptions = + torch::RegisterOperators() + .op("torchaudio::sox_utils_set_seed", &sox_utils::set_seed) + .op("torchaudio::sox_utils_set_verbosity", &sox_utils::set_verbosity) + .op("torchaudio::sox_utils_set_use_threads", + &sox_utils::set_use_threads) + .op("torchaudio::sox_utils_set_buffer_size", + &sox_utils::set_buffer_size) + .op("torchaudio::sox_utils_list_effects", &sox_utils::list_effects) + .op("torchaudio::sox_utils_list_formats", &sox_utils::list_formats); + //////////////////////////////////////////////////////////////////////////////// // sox_io.h //////////////////////////////////////////////////////////////////////////////// @@ -53,12 +64,23 @@ static auto registerSaveAudioFile = torch::RegisterOperators().op( // sox_effects.h //////////////////////////////////////////////////////////////////////////////// static auto registerSoxEffects = - torch::RegisterOperators( - "torchaudio::sox_effects_initialize_sox_effects", - &sox_effects::initialize_sox_effects) + torch::RegisterOperators() + .op("torchaudio::sox_effects_initialize_sox_effects", + &sox_effects::initialize_sox_effects) .op("torchaudio::sox_effects_shutdown_sox_effects", &sox_effects::shutdown_sox_effects) - .op("torchaudio::sox_effects_list_effects", &sox_effects::list_effects); + .op(torch::RegisterOperators::options() + .schema( + "torchaudio::sox_effects_apply_effects_tensor(__torch__.torch.classes.torchaudio.TensorSignal input_signal, str[][] effects) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal") + .catchAllKernel< + decltype(sox_effects::apply_effects_tensor), + &sox_effects::apply_effects_tensor>()) + .op(torch::RegisterOperators::options() + .schema( + "torchaudio::sox_effects_apply_effects_file(str path, str[][] effects, bool normalize, bool channels_first) -> __torch__.torch.classes.torchaudio.TensorSignal output_signal") + .catchAllKernel< + decltype(sox_effects::apply_effects_file), + &sox_effects::apply_effects_file>()); } // namespace } // namespace torchaudio diff --git a/torchaudio/csrc/sox_effects.cpp b/torchaudio/csrc/sox_effects.cpp index 9a0c2ddc6f..8a7cc7c494 100644 --- a/torchaudio/csrc/sox_effects.cpp +++ b/torchaudio/csrc/sox_effects.cpp @@ -1,7 +1,9 @@ #include #include +#include +#include -using namespace torch::indexing; +using namespace torchaudio::sox_utils; namespace torchaudio { namespace sox_effects { @@ -10,44 +12,125 @@ namespace { enum SoxEffectsResourceState { NotInitialized, Initialized, ShutDown }; SoxEffectsResourceState SOX_RESOURCE_STATE = NotInitialized; +std::mutex SOX_RESOUCE_STATE_MUTEX; } // namespace void initialize_sox_effects() { - if (SOX_RESOURCE_STATE == ShutDown) { - throw std::runtime_error( - "SoX Effects has been shut down. Cannot initialize again."); - } - if (SOX_RESOURCE_STATE == NotInitialized) { - if (sox_init() != SOX_SUCCESS) { - throw std::runtime_error("Failed to initialize sox effects."); - }; - SOX_RESOURCE_STATE = Initialized; + const std::lock_guard lock(SOX_RESOUCE_STATE_MUTEX); + + switch (SOX_RESOURCE_STATE) { + case NotInitialized: + if (sox_init() != SOX_SUCCESS) { + throw std::runtime_error("Failed to initialize sox effects."); + }; + SOX_RESOURCE_STATE = Initialized; + case Initialized: + break; + case ShutDown: + throw std::runtime_error( + "SoX Effects has been shut down. Cannot initialize again."); } }; void shutdown_sox_effects() { - if (SOX_RESOURCE_STATE == NotInitialized) { - throw std::runtime_error( - "SoX Effects is not initialized. Cannot shutdown."); + const std::lock_guard lock(SOX_RESOUCE_STATE_MUTEX); + + switch (SOX_RESOURCE_STATE) { + case NotInitialized: + throw std::runtime_error( + "SoX Effects is not initialized. Cannot shutdown."); + case Initialized: + if (sox_quit() != SOX_SUCCESS) { + throw std::runtime_error("Failed to initialize sox effects."); + }; + SOX_RESOURCE_STATE = ShutDown; + case ShutDown: + break; } - if (SOX_RESOURCE_STATE == Initialized) { - if (sox_quit() != SOX_SUCCESS) { - throw std::runtime_error("Failed to initialize sox effects."); - }; - SOX_RESOURCE_STATE = ShutDown; +} + +c10::intrusive_ptr apply_effects_tensor( + const c10::intrusive_ptr& input_signal, + std::vector> effects) { + auto in_tensor = input_signal->getTensor(); + validate_input_tensor(in_tensor); + + // Create SoxEffectsChain + const auto dtype = in_tensor.dtype(); + torchaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/get_encodinginfo("wav", dtype, 0.), + /*output_encoding=*/get_encodinginfo("wav", dtype, 0.)); + + // Prepare output buffer + std::vector out_buffer; + out_buffer.reserve(in_tensor.numel()); + + // Build and run effects chain + chain.addInputTensor(input_signal.get()); + for (const auto& effect : effects) { + chain.addEffect(effect); } + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + const auto channels_first = input_signal->getChannelsFirst(); + auto out_tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + /*noramlize=*/false, + channels_first); + + return c10::make_intrusive( + out_tensor, chain.getOutputSampleRate(), channels_first); } -std::vector list_effects() { - std::vector names; - const sox_effect_fn_t* fns = sox_get_effect_fns(); - for (int i = 0; fns[i]; ++i) { - const sox_effect_handler_t* handler = fns[i](); - if (handler && handler->name) - names.push_back(handler->name); +c10::intrusive_ptr apply_effects_file( + const std::string path, + std::vector> effects, + const bool normalize, + const bool channels_first) { + // Open input file + SoxFormat sf(sox_open_read( + path.c_str(), + /*signal=*/nullptr, + /*encoding=*/nullptr, + /*filetype=*/nullptr)); + + validate_input_file(sf); + + const auto dtype = get_dtype(sf->encoding.encoding, sf->signal.precision); + + // Prepare output + std::vector out_buffer; + out_buffer.reserve(sf->signal.length); + + // Create and run SoxEffectsChain + torchaudio::sox_effects_chain::SoxEffectsChain chain( + /*input_encoding=*/sf->encoding, + /*output_encoding=*/get_encodinginfo("wav", dtype, 0.)); + + chain.addInputFile(sf); + for (const auto& effect : effects) { + chain.addEffect(effect); } - return names; + chain.addOutputBuffer(&out_buffer); + chain.run(); + + // Create tensor from buffer + auto tensor = convert_to_tensor( + /*buffer=*/out_buffer.data(), + /*num_samples=*/out_buffer.size(), + /*num_channels=*/chain.getOutputNumChannels(), + dtype, + normalize, + channels_first); + + return c10::make_intrusive( + tensor, chain.getOutputSampleRate(), channels_first); } } // namespace sox_effects diff --git a/torchaudio/csrc/sox_effects.h b/torchaudio/csrc/sox_effects.h index 14bdbbfabc..7883a7beac 100644 --- a/torchaudio/csrc/sox_effects.h +++ b/torchaudio/csrc/sox_effects.h @@ -2,6 +2,7 @@ #define TORCHAUDIO_SOX_EFFECTS_H #include +#include namespace torchaudio { namespace sox_effects { @@ -10,7 +11,15 @@ void initialize_sox_effects(); void shutdown_sox_effects(); -std::vector list_effects(); +c10::intrusive_ptr apply_effects_tensor( + const c10::intrusive_ptr& input_signal, + std::vector> effects); + +c10::intrusive_ptr apply_effects_file( + const std::string path, + std::vector> effects, + const bool normalize = true, + const bool channels_first = true); } // namespace sox_effects } // namespace torchaudio diff --git a/torchaudio/csrc/sox_effects_chain.cpp b/torchaudio/csrc/sox_effects_chain.cpp new file mode 100644 index 0000000000..05d730b6e7 --- /dev/null +++ b/torchaudio/csrc/sox_effects_chain.cpp @@ -0,0 +1,236 @@ +#include +#include + +using namespace torch::indexing; +using namespace torchaudio::sox_utils; + +namespace torchaudio { +namespace sox_effects_chain { + +namespace { + +// Helper struct to safely close sox_effect_t* pointer returned by +// sox_create_effect +struct SoxEffect { + explicit SoxEffect(sox_effect_t* se) noexcept : se_(se){}; + SoxEffect(const SoxEffect& other) = delete; + SoxEffect(const SoxEffect&& other) = delete; + SoxEffect& operator=(const SoxEffect& other) = delete; + SoxEffect& operator=(SoxEffect&& other) = delete; + ~SoxEffect() { + if (se_ != nullptr) { + free(se_); + } + } + operator sox_effect_t*() const { + return se_; + }; + sox_effect_t* operator->() noexcept { + return se_; + } + + private: + sox_effect_t* se_; +}; + +/// helper classes for passing the location of input tensor and output buffer +/// +/// drain/flow callback functions require plaing C style function signature and +/// the way to pass extra data is to attach data to sox_fffect_t::priv pointer. +/// The following structs will be assigned to sox_fffect_t::priv pointer which +/// gives sox_effect_t an access to input Tensor and output buffer object. +struct TensorInputPriv { + size_t index; + TensorSignal* signal; +}; +struct TensorOutputPriv { + std::vector* buffer; +}; + +/// Callback function to feed Tensor data to SoxEffectChain. +int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) { + // Retrieve the input Tensor and current index + auto priv = static_cast(effp->priv); + auto index = priv->index; + auto signal = priv->signal; + auto tensor = signal->getTensor(); + auto num_channels = effp->out_signal.channels; + + // Adjust the number of samples to read + const size_t num_samples = tensor.numel(); + if (index + *osamp > num_samples) { + *osamp = num_samples - index; + } + // Ensure that it's a multiple of the number of channels + *osamp -= *osamp % num_channels; + + // Slice the input Tensor and unnormalize the values + const auto tensor_ = [&]() { + auto i_frame = index / num_channels; + auto num_frames = *osamp / num_channels; + auto t = (signal->getChannelsFirst()) + ? tensor.index({Slice(), Slice(i_frame, i_frame + num_frames)}).t() + : tensor.index({Slice(i_frame, i_frame + num_frames), Slice()}); + return unnormalize_wav(t.reshape({-1})).contiguous(); + }(); + priv->index += *osamp; + + // Write data to SoxEffectsChain buffer. + auto ptr = tensor_.data_ptr(); + std::copy(ptr, ptr + *osamp, obuf); + + return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS; +} + +/// Callback function to fetch data from SoxEffectChain. +int tensor_output_flow( + sox_effect_t* effp LSX_UNUSED, + sox_sample_t const* ibuf, + sox_sample_t* obuf LSX_UNUSED, + size_t* isamp, + size_t* osamp) { + *osamp = 0; + // Get output buffer + auto out_buffer = static_cast(effp->priv)->buffer; + // Append at the end + out_buffer->insert(out_buffer->end(), ibuf, ibuf + *isamp); + return SOX_SUCCESS; +} + +sox_effect_handler_t* get_tensor_input_handler() { + static sox_effect_handler_t handler{/*name=*/"input_tensor", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/NULL, + /*drain=*/tensor_input_drain, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(TensorInputPriv)}; + return &handler; +} + +sox_effect_handler_t* get_tensor_output_handler() { + static sox_effect_handler_t handler{/*name=*/"output_tensor", + /*usage=*/NULL, + /*flags=*/SOX_EFF_MCHAN, + /*getopts=*/NULL, + /*start=*/NULL, + /*flow=*/tensor_output_flow, + /*drain=*/NULL, + /*stop=*/NULL, + /*kill=*/NULL, + /*priv_size=*/sizeof(TensorOutputPriv)}; + return &handler; +} + +} // namespace + +SoxEffectsChain::SoxEffectsChain( + sox_encodinginfo_t input_encoding, + sox_encodinginfo_t output_encoding) + : in_enc_(input_encoding), + out_enc_(output_encoding), + in_sig_(), + interm_sig_(), + sec_(sox_create_effects_chain(&in_enc_, &out_enc_)) { + if (!sec_) { + throw std::runtime_error("Failed to create effect chain."); + } +} + +SoxEffectsChain::~SoxEffectsChain() { + if (sec_ != nullptr) { + sox_delete_effects_chain(sec_); + } +} + +void SoxEffectsChain::run() { + sox_flow_effects(sec_, NULL, NULL); +} + +void SoxEffectsChain::addInputTensor(TensorSignal* signal) { + in_sig_ = get_signalinfo(signal, "wav"); + interm_sig_ = in_sig_; + SoxEffect e(sox_create_effect(get_tensor_input_handler())); + auto priv = static_cast(e->priv); + priv->signal = signal; + priv->index = 0; + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + throw std::runtime_error("Failed to add effect: input_tensor"); + } +} + +void SoxEffectsChain::addOutputBuffer( + std::vector* output_buffer) { + SoxEffect e(sox_create_effect(get_tensor_output_handler())); + static_cast(e->priv)->buffer = output_buffer; + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + throw std::runtime_error("Failed to add effect: output_tensor"); + } +} + +void SoxEffectsChain::addInputFile(sox_format_t* sf) { + in_sig_ = sf->signal; + interm_sig_ = in_sig_; + SoxEffect e(sox_create_effect(sox_find_effect("input"))); + char* opts[] = {(char*)sf}; + sox_effect_options(e, 1, opts); + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Failed to add effect: input " << sf->filename; + throw std::runtime_error(stream.str()); + } +} + +void SoxEffectsChain::addEffect(const std::vector effect) { + const auto num_args = effect.size(); + if (num_args == 0) { + throw std::runtime_error("Invalid argument: empty effect."); + } + const auto name = effect[0]; + if (UNSUPPORTED_EFFECTS.find(name) != UNSUPPORTED_EFFECTS.end()) { + std::ostringstream stream; + stream << "Unsupported effect: " << name; + throw std::runtime_error(stream.str()); + } + + SoxEffect e(sox_create_effect(sox_find_effect(name.c_str()))); + const auto num_options = num_args - 1; + + std::vector opts; + for (size_t i = 1; i < num_args; ++i) { + opts.push_back((char*)effect[i].c_str()); + } + if (sox_effect_options(e, num_options, num_options ? opts.data() : nullptr) != + SOX_SUCCESS) { + std::ostringstream stream; + stream << "Invalid effect option:"; + for (const auto& v : effect) { + stream << " " << v; + } + throw std::runtime_error(stream.str()); + } + + if (sox_add_effect(sec_, e, &interm_sig_, &in_sig_) != SOX_SUCCESS) { + std::ostringstream stream; + stream << "Failed to add effect: \"" << name; + for (size_t i = 1; i < num_args; ++i) { + stream << " " << effect[i]; + } + stream << "\""; + throw std::runtime_error(stream.str()); + } +} + +int64_t SoxEffectsChain::getOutputNumChannels() { + return interm_sig_.channels; +} + +int64_t SoxEffectsChain::getOutputSampleRate() { + return interm_sig_.rate; +} + +} // namespace sox_effects_chain +} // namespace torchaudio diff --git a/torchaudio/csrc/sox_effects_chain.h b/torchaudio/csrc/sox_effects_chain.h new file mode 100644 index 0000000000..9168e94121 --- /dev/null +++ b/torchaudio/csrc/sox_effects_chain.h @@ -0,0 +1,40 @@ +#ifndef TORCHAUDIO_SOX_EFFECTS_CHAIN_H +#define TORCHAUDIO_SOX_EFFECTS_CHAIN_H + +#include +#include +#include + +namespace torchaudio { +namespace sox_effects_chain { + +// Helper struct to safely close sox_effects_chain_t with handy methods +class SoxEffectsChain { + const sox_encodinginfo_t in_enc_; + const sox_encodinginfo_t out_enc_; + sox_signalinfo_t in_sig_; + sox_signalinfo_t interm_sig_; + sox_effects_chain_t* sec_; + + public: + explicit SoxEffectsChain( + sox_encodinginfo_t input_encoding, + sox_encodinginfo_t output_encoding); + SoxEffectsChain(const SoxEffectsChain& other) = delete; + SoxEffectsChain(const SoxEffectsChain&& other) = delete; + SoxEffectsChain& operator=(const SoxEffectsChain& other) = delete; + SoxEffectsChain& operator=(SoxEffectsChain&& other) = delete; + ~SoxEffectsChain(); + void run(); + void addInputTensor(torchaudio::sox_utils::TensorSignal* signal); + void addInputFile(sox_format_t* sf); + void addOutputBuffer(std::vector* output_buffer); + void addEffect(const std::vector effect); + int64_t getOutputNumChannels(); + int64_t getOutputSampleRate(); +}; + +} // namespace sox_effects_chain +} // namespace torchaudio + +#endif diff --git a/torchaudio/csrc/sox_io.cpp b/torchaudio/csrc/sox_io.cpp index 5d308027bb..2785c7910d 100644 --- a/torchaudio/csrc/sox_io.cpp +++ b/torchaudio/csrc/sox_io.cpp @@ -125,14 +125,12 @@ void save_audio_file( const c10::intrusive_ptr& signal, const double compression) { const auto tensor = signal->getTensor(); - const auto sample_rate = signal->getSampleRate(); const auto channels_first = signal->getChannelsFirst(); validate_input_tensor(tensor); const auto filetype = get_filetype(file_name); - const auto signal_info = - get_signalinfo(tensor, sample_rate, channels_first, filetype); + const auto signal_info = get_signalinfo(signal.get(), filetype); const auto encoding_info = get_encodinginfo(filetype, tensor.dtype(), compression); diff --git a/torchaudio/csrc/sox_utils.cpp b/torchaudio/csrc/sox_utils.cpp index c1fd8383a8..61eac6f306 100644 --- a/torchaudio/csrc/sox_utils.cpp +++ b/torchaudio/csrc/sox_utils.cpp @@ -5,6 +5,49 @@ namespace torchaudio { namespace sox_utils { +void set_seed(const int64_t seed) { + sox_get_globals()->ranqd1 = static_cast(seed); +} + +void set_verbosity(const int64_t verbosity) { + sox_get_globals()->verbosity = static_cast(verbosity); +} + +void set_use_threads(const bool use_threads) { + sox_get_globals()->use_threads = static_cast(use_threads); +} + +void set_buffer_size(const int64_t buffer_size) { + sox_get_globals()->bufsiz = static_cast(buffer_size); +} + +std::vector> list_effects() { + std::vector> effects; + for (const sox_effect_fn_t* fns = sox_get_effect_fns(); *fns; ++fns) { + const sox_effect_handler_t* handler = (*fns)(); + if (handler && handler->name) { + if (UNSUPPORTED_EFFECTS.find(handler->name) == + UNSUPPORTED_EFFECTS.end()) { + effects.emplace_back(std::vector{ + handler->name, + handler->usage ? std::string(handler->usage) : std::string("")}); + } + } + } + return effects; +} + +std::vector list_formats() { + std::vector formats; + for (const sox_format_tab_t* fns = sox_get_format_fns(); fns->fn; ++fns) { + for (const char* const* names = fns->fn()->names; *names; ++names) { + if (!strchr(*names, '/')) + formats.emplace_back(*names); + } + } + return formats; +} + TensorSignal::TensorSignal( torch::Tensor tensor_, int64_t sample_rate_, @@ -205,13 +248,13 @@ unsigned get_precision( } sox_signalinfo_t get_signalinfo( - const torch::Tensor& tensor, - const int64_t sample_rate, - const bool channels_first, + const TensorSignal* signal, const std::string filetype) { + auto tensor = signal->getTensor(); return sox_signalinfo_t{ - /*rate=*/static_cast(sample_rate), - /*channels=*/static_cast(tensor.size(channels_first ? 0 : 1)), + /*rate=*/static_cast(signal->getSampleRate()), + /*channels=*/ + static_cast(tensor.size(signal->getChannelsFirst() ? 0 : 1)), /*precision=*/get_precision(filetype, tensor.dtype()), /*length=*/static_cast(tensor.numel())}; } diff --git a/torchaudio/csrc/sox_utils.h b/torchaudio/csrc/sox_utils.h index 665187c840..7c0cf7b3f1 100644 --- a/torchaudio/csrc/sox_utils.h +++ b/torchaudio/csrc/sox_utils.h @@ -7,6 +7,25 @@ namespace torchaudio { namespace sox_utils { +//////////////////////////////////////////////////////////////////////////////// +// APIs for Python interaction +//////////////////////////////////////////////////////////////////////////////// + +/// Set sox global options +void set_seed(const int64_t seed); + +void set_verbosity(const int64_t verbosity); + +void set_use_threads(const bool use_threads); + +void set_buffer_size(const int64_t buffer_size); + +std::vector> list_effects(); + +std::vector list_formats(); + +/// Class for exchanging signal infomation (tensor + meta data) between +/// C++ and Python for read/write operation. struct TensorSignal : torch::CustomClassHolder { torch::Tensor tensor; int64_t sample_rate; @@ -22,6 +41,13 @@ struct TensorSignal : torch::CustomClassHolder { bool getChannelsFirst() const; }; +//////////////////////////////////////////////////////////////////////////////// +// Utilities for sox_io / sox_effects implementations +//////////////////////////////////////////////////////////////////////////////// + +const std::unordered_set UNSUPPORTED_EFFECTS = + {"input", "output", "spectrogram", "noiseprof", "noisered", "splice"}; + /// helper class to automatically close sox_format_t* struct SoxFormat { explicit SoxFormat(sox_format_t* fd) noexcept; @@ -84,9 +110,7 @@ const std::string get_filetype(const std::string path); /// Get sox_signalinfo_t for passing a torch::Tensor object. sox_signalinfo_t get_signalinfo( - const torch::Tensor& tensor, - const int64_t sample_rate, - const bool channels_first, + const TensorSignal* signal, const std::string filetype); /// Get sox_encofinginfo_t for saving audoi file diff --git a/torchaudio/sox_effects/__init__.py b/torchaudio/sox_effects/__init__.py index 507dc5c3af..d9650173c5 100644 --- a/torchaudio/sox_effects/__init__.py +++ b/torchaudio/sox_effects/__init__.py @@ -3,6 +3,8 @@ init_sox_effects, shutdown_sox_effects, effect_names, + apply_effects_tensor, + apply_effects_file, SoxEffect, SoxEffectsChain, ) diff --git a/torchaudio/sox_effects/sox_effects.py b/torchaudio/sox_effects/sox_effects.py index 0aee312126..0e5591a06b 100644 --- a/torchaudio/sox_effects/sox_effects.py +++ b/torchaudio/sox_effects/sox_effects.py @@ -7,6 +7,8 @@ module_utils as _mod_utils, misc_ops as _misc_ops, ) +from torchaudio.utils.sox_utils import list_effects + if _mod_utils.is_module_available('torchaudio._torchaudio'): from torchaudio import _torchaudio @@ -52,7 +54,128 @@ def effect_names() -> List[str]: Example >>> EFFECT_NAMES = torchaudio.sox_effects.effect_names() """ - return torch.ops.torchaudio.sox_effects_list_effects() + return list(list_effects().keys()) + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def apply_effects_tensor( + tensor: torch.Tensor, + sample_rate: int, + effects: List[List[str]], + channels_first: bool = True, +) -> Tuple[torch.Tensor, int]: + """Apply sox effects to given Tensor + + Args: + tensor (torch.Tensor): Input 2D Tensor. + sample_rate (int): Sample rate + effects (List[List[str]]): List of effects. + channels_first (bool): Indicates if the input Tensor's dimension is + ``[channels, time]`` or ``[time, channels]`` + + Returns: + Tuple[torch.Tensor, int]: Resulting Tensor and sample rate. + The resulting Tensor has the same ``dtype`` as the input Tensor, and + the same channels order. The shape of the Tensor can be different based on the + effects applied. Sample rate can also be different based on the effects applied. + + Notes: + This function works in the way very similar to ``sox`` command, however there are slight + differences. For example, ``sox`` commnad adds certain effects automatically (such as + ``rate`` effect after ``speed`` and ``pitch`` and other effects), but this function does + only applies the given effects. (Therefore, to actually apply ``speed`` effect, you also + need to give ``rate`` effect with desired sampling rate.) + + Examples: + >>> # Defines the effects to apply + >>> effects = [ + ... ['gain', '-n'], # normalises to 0dB + ... ['pitch', '5'], # 5 cent pitch shift + ... ['rate', '8000'], # resample to 8000 Hz + ... ] + >>> # Generate pseudo wave: + >>> # normalized, channels first, 2ch, sampling rate 16000, 1 second + >>> sample_rate = 16000 + >>> waveform = 2 * torch.rand([2, sample_rate * 1]) - 1 + >>> waveform.shape + torch.Size([2, 16000]) + >>> waveform + tensor([[ 0.3138, 0.7620, -0.9019, ..., -0.7495, -0.4935, 0.5442], + [-0.0832, 0.0061, 0.8233, ..., -0.5176, -0.9140, -0.2434]]) + >>> # Apply effects + >>> waveform, sample_rate = apply_effects_tensor( + ... wave_form, sample_rate, effects, channels_first=True) + >>> # The new waveform is sampling rate 8000, 1 second. + >>> # normalization and channel order are preserved + >>> waveform.shape + torch.Size([2, 8000]) + >>> waveform + tensor([[ 0.5054, -0.5518, -0.4800, ..., -0.0076, 0.0096, -0.0110], + [ 0.1331, 0.0436, -0.3783, ..., -0.0035, 0.0012, 0.0008]]) + >>> sample_rate + 8000 + """ + in_signal = torch.classes.torchaudio.TensorSignal(tensor, sample_rate, channels_first) + out_signal = torch.ops.torchaudio.sox_effects_apply_effects_tensor(in_signal, effects) + return out_signal.get_tensor(), out_signal.get_sample_rate() + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def apply_effects_file( + path: str, + effects: List[List[str]], + normalize: bool = True, + channels_first: bool = True, +) -> Tuple[torch.Tensor, int]: + """Apply sox effects to the audio file and load the resulting data as Tensor + + Args: + path (str): Path to the audio file. + effects (List[List[str]]): List of effects. + normalize (bool): When ``True``, this function always return ``float32``, and sample values are + normalized to ``[-1.0, 1.0]``. If input file is integer WAV, giving ``False`` will change + the resulting Tensor type to integer type. This argument has no effect for formats other + than integer WAV type. + channels_first (bool): When True, the returned Tensor has dimension ``[channel, time]``. + Otherwise, the returned Tensor's dimension is ``[time, channel]``. + + Returns: + Tuple[torch.Tensor, int]: Resulting Tensor and sample rate. + If ``normalize=True``, the resulting Tensor is always ``float32`` type. + If ``normalize=False`` and the input audio file is of integer WAV file, then the + resulting Tensor has corresponding integer type. (Note 24 bit integer type is not supported) + If ``channels_first=True``, the resulting Tensor has dimension ``[channel, time]``, + otherwise ``[time, channel]``. + + Notes: + This function works in the way very similar to ``sox`` command, however there are slight + differences. For example, ``sox`` commnad adds certain effects automatically (such as + ``rate`` effect after ``speed``, ``pitch`` etc), but this function only applies the given + effects. Therefore, to actually apply ``speed`` effect, you also need to give ``rate`` + effect with desired sampling rate, because internally, ``speed`` effects only alter sampling + rate and leave samples untouched. + + Examples: + >>> # Defines the effects to apply + >>> effects = [ + ... ['gain', '-n'], # normalises to 0dB + ... ['pitch', '5'], # 5 cent pitch shift + ... ['rate', '8000'], # resample to 8000 Hz + ... ] + >>> # Apply effects and load data with channels_first=True + >>> waveform, sample_rate = apply_effects_file("data.wav", effects, channels_first=True) + >>> waveform.shape + torch.Size([2, 8000]) + >>> waveform + tensor([[ 5.1151e-03, 1.8073e-02, 2.2188e-02, ..., 1.0431e-07, + -1.4761e-07, 1.8114e-07], + [-2.6924e-03, 2.1860e-03, 1.0650e-02, ..., 6.4122e-07, + -5.6159e-07, 4.8103e-07]]) + >>> sample_rate + 8000 + """ + signal = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first) + return signal.get_tensor(), signal.get_sample_rate() @_mod_utils.requires_module('torchaudio._torchaudio') diff --git a/torchaudio/utils/__init__.py b/torchaudio/utils/__init__.py new file mode 100644 index 0000000000..bc11f86893 --- /dev/null +++ b/torchaudio/utils/__init__.py @@ -0,0 +1,9 @@ +from . import ( + sox_utils, +) + +from torchaudio._internal import module_utils as _mod_utils + + +if _mod_utils.is_module_available('torchaudio._torchaudio'): + sox_utils.set_verbosity(1) diff --git a/torchaudio/utils/sox_utils.py b/torchaudio/utils/sox_utils.py new file mode 100644 index 0000000000..d9901472ec --- /dev/null +++ b/torchaudio/utils/sox_utils.py @@ -0,0 +1,84 @@ +from typing import List, Dict + +import torch + +from torchaudio._internal import ( + module_utils as _mod_utils, +) + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def set_seed(seed: int): + """Set libsox's PRNG + + Args: + seed (int): seed value. valid range is int32. + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_seed(seed) + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def set_verbosity(verbosity: int): + """Set libsox's verbosity + + Args: + verbosity (int): Set verbosity level of libsox. + 1: failure messages + 2: warnings + 3: details of processing + 4-6: increasing levels of debug messages + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_verbosity(verbosity) + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def set_buffer_size(buffer_size: int): + """Set buffer size for sox effect chain + + Args: + buffer_size (int): Set the size in bytes of the buffers used for processing audio. + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_buffer_size(buffer_size) + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def set_use_threads(use_threads: bool): + """Set multithread option for sox effect chain + + Args: + use_threads (bool): When True, enables libsox's parallel effects channels processing. + To use mutlithread, the underlying libsox has to be compiled with OpenMP support. + + See Also: + http://sox.sourceforge.net/sox.html + """ + torch.ops.torchaudio.sox_utils_set_use_threads(use_threads) + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def list_effects() -> Dict[str, str]: + """List the available sox effect names + + Returns: + Dict[str, str]: Mapping from "effect name" to "usage" + """ + return dict(torch.ops.torchaudio.sox_utils_list_effects()) + + +@_mod_utils.requires_module('torchaudio._torchaudio') +def list_formats() -> List[str]: + """List the supported audio formats + + Returns: + List[str]: List of supported audio formats + """ + return torch.ops.torchaudio.sox_utils_list_formats()