diff --git a/test/torchaudio_unittest/sox_effect/sox_effect_test.py b/test/torchaudio_unittest/sox_effect/sox_effect_test.py index 27740ca576..c5b57eaf5f 100644 --- a/test/torchaudio_unittest/sox_effect/sox_effect_test.py +++ b/test/torchaudio_unittest/sox_effect/sox_effect_test.py @@ -1,4 +1,5 @@ import itertools +from pathlib import Path from torchaudio import sox_effects from parameterized import parameterized @@ -104,7 +105,7 @@ def test_apply_no_effect(self, dtype, sample_rate, num_channels, channels_first) load_params("sox_effect_test_args.json"), name_func=lambda f, i, p: f'{f.__name__}_{i}_{p.args[0]["effects"][0][0]}', ) - def test_apply_effects(self, args): + def test_apply_effects_str(self, args): """`apply_effects_file` should return identical data as sox command""" dtype = 'int32' channels_first = True @@ -127,6 +128,29 @@ def test_apply_effects(self, args): assert sr == expected_sr self.assertEqual(found, expected) + def test_apply_effects_path(self): + """`apply_effects_file` should return identical data as sox command when file path is given as a Path Object""" + dtype = 'int32' + channels_first = True + effects = [["hilbert"]] + num_channels = 2 + input_sr = 8000 + output_sr = 8000 + + input_path = self.get_temp_path('input.wav') + reference_path = self.get_temp_path('reference.wav') + data = get_wav_data(dtype, num_channels, channels_first=channels_first) + save_wav(input_path, data, input_sr, channels_first=channels_first) + sox_utils.run_sox_effect( + input_path, reference_path, effects, output_sample_rate=output_sr) + + expected, expected_sr = load_wav(reference_path) + found, sr = sox_effects.apply_effects_file( + Path(input_path), effects, normalize=False, channels_first=channels_first) + + assert sr == expected_sr + self.assertEqual(found, expected) + @skipIfNoExtension class TestFileFormats(TempDirMixin, PytorchTestCase): diff --git a/torchaudio/sox_effects/sox_effects.py b/torchaudio/sox_effects/sox_effects.py index 8b50d19aa7..2fb4694d31 100644 --- a/torchaudio/sox_effects/sox_effects.py +++ b/torchaudio/sox_effects/sox_effects.py @@ -1,4 +1,5 @@ -from typing import List, Tuple +from typing import List, Tuple, Union +from pathlib import Path import torch @@ -169,7 +170,8 @@ def apply_effects_file( rate and leave samples untouched. Args: - path (str): Path to the audio file. + path (str or pathlib.Path): Path to the audio file. This function also handles ``pathlib.Path`` objects, but is + annotated as ``str`` for TorchScript compiler compatibility. effects (List[List[str]]): List of effects. normalize (bool): When ``True``, this function always return ``float32``, and sample values are @@ -247,5 +249,7 @@ def apply_effects_file( >>> for batch in loader: >>> pass """ + # Get string representation of 'path' in case Path object is passed + path = str(path) signal = torch.ops.torchaudio.sox_effects_apply_effects_file(path, effects, normalize, channels_first) return signal.get_tensor(), signal.get_sample_rate()