diff --git a/test/torchaudio_unittest/sox_io_backend/save_test.py b/test/torchaudio_unittest/sox_io_backend/save_test.py index f7d4da3931..eadce92f96 100644 --- a/test/torchaudio_unittest/sox_io_backend/save_test.py +++ b/test/torchaudio_unittest/sox_io_backend/save_test.py @@ -235,7 +235,7 @@ def test_multiple_channels(self, dtype, num_channels): @parameterized.expand(list(itertools.product( [8000, 16000], [1, 2], - [-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], + [None, -4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320], )), name_func=name_func) def test_mp3(self, sample_rate, num_channels, bit_rate): """`sox_io_backend.save` can save mp3 format.""" @@ -254,7 +254,7 @@ def test_mp3_large(self, sample_rate, num_channels, bit_rate): @parameterized.expand(list(itertools.product( [8000, 16000], [1, 2], - list(range(9)), + [None] + list(range(9)), )), name_func=name_func) def test_flac(self, sample_rate, num_channels, compression_level): """`sox_io_backend.save` can save flac format.""" @@ -273,7 +273,7 @@ def test_flac_large(self, sample_rate, num_channels, compression_level): @parameterized.expand(list(itertools.product( [8000, 16000], [1, 2], - [-1, 0, 1, 2, 3, 3.6, 5, 10], + [None, -1, 0, 1, 2, 3, 3.6, 5, 10], )), name_func=name_func) def test_vorbis(self, sample_rate, num_channels, quality_level): """`sox_io_backend.save` can save vorbis format.""" diff --git a/torchaudio/backend/sox_io_backend.py b/torchaudio/backend/sox_io_backend.py index 9c0695d723..31e69c443e 100644 --- a/torchaudio/backend/sox_io_backend.py +++ b/torchaudio/backend/sox_io_backend.py @@ -159,7 +159,7 @@ def save( See the detail at http://sox.sourceforge.net/soxformat.html. """ if compression is None: - ext = str(filepath)[-3:].lower() + ext = str(filepath).split('.')[-1].lower() if ext in ['wav', 'sph']: compression = 0. elif ext == 'mp3':