Skip to content

Commit 080cd30

Browse files
authored
Fix incorrect extension parsing in sox_io_backend.save(#885)
* Fix incorrect extension parsing in sox_io_backend.save * Add tests for compression=None
1 parent 2205cc9 commit 080cd30

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

test/torchaudio_unittest/sox_io_backend/save_test.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def test_multiple_channels(self, dtype, num_channels):
235235
@parameterized.expand(list(itertools.product(
236236
[8000, 16000],
237237
[1, 2],
238-
[-4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
238+
[None, -4.2, -0.2, 0, 0.2, 96, 128, 160, 192, 224, 256, 320],
239239
)), name_func=name_func)
240240
def test_mp3(self, sample_rate, num_channels, bit_rate):
241241
"""`sox_io_backend.save` can save mp3 format."""
@@ -254,7 +254,7 @@ def test_mp3_large(self, sample_rate, num_channels, bit_rate):
254254
@parameterized.expand(list(itertools.product(
255255
[8000, 16000],
256256
[1, 2],
257-
list(range(9)),
257+
[None] + list(range(9)),
258258
)), name_func=name_func)
259259
def test_flac(self, sample_rate, num_channels, compression_level):
260260
"""`sox_io_backend.save` can save flac format."""
@@ -273,7 +273,7 @@ def test_flac_large(self, sample_rate, num_channels, compression_level):
273273
@parameterized.expand(list(itertools.product(
274274
[8000, 16000],
275275
[1, 2],
276-
[-1, 0, 1, 2, 3, 3.6, 5, 10],
276+
[None, -1, 0, 1, 2, 3, 3.6, 5, 10],
277277
)), name_func=name_func)
278278
def test_vorbis(self, sample_rate, num_channels, quality_level):
279279
"""`sox_io_backend.save` can save vorbis format."""

torchaudio/backend/sox_io_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ def save(
159159
See the detail at http://sox.sourceforge.net/soxformat.html.
160160
"""
161161
if compression is None:
162-
ext = str(filepath)[-3:].lower()
162+
ext = str(filepath).split('.')[-1].lower()
163163
if ext in ['wav', 'sph']:
164164
compression = 0.
165165
elif ext == 'mp3':

0 commit comments

Comments
 (0)